コード例 #1
0
ファイル: test_tracker_stores.py プロジェクト: wavymazy/rasa
def test_tracker_store_deprecated_url_argument_from_string(domain: Domain):
    """`TrackerStore.create` raises when the configured type is `URLExampleTrackerStore`."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    endpoint_config = read_endpoint_config(config_file, "tracker_store")
    # Point the endpoint config at the test-module class instead of the file's type.
    endpoint_config.type = "tests.core.test_tracker_stores.URLExampleTrackerStore"

    with pytest.raises(Exception):
        TrackerStore.create(endpoint_config, domain)
コード例 #2
0
def test_sql_tracker_store_creation_with_invalid_port(domain: Domain):
    """A non-integer port (the literal string `$DB_PORT`) is rejected for SQL stores.

    The raised `RasaException` must explain that the port cannot be cast.
    """
    expected_message = "port '$DB_PORT' cannot be cast to integer."

    with pytest.raises(RasaException) as excinfo:
        TrackerStore.create(EndpointConfig(port="$DB_PORT", type="sql"), domain)

    assert expected_message in str(excinfo.value)
コード例 #3
0
    def __init__(self, domain):
        """Initialise the store on a fakeredis client instead of a real Redis server."""
        fake_client = fakeredis.FakeStrictRedis()
        self.red = fake_client
        self.record_exp = None

        # `health_check_interval` was added in redis 3.3.0 but fakeredis does not
        # provide it yet, so define it on the connection class explicitly.
        fake_client.connection_pool.connection_class.health_check_interval = 0

        TrackerStore.__init__(self, domain)
コード例 #4
0
def test_load_all(marker_trackerstore: TrackerStore):
    """The 'all' strategy yields one tracker per store key, each existing in the store."""
    loaded = list(MarkerTrackerLoader(marker_trackerstore, STRATEGY_ALL).load())

    expected_count = len(list(marker_trackerstore.keys()))
    assert len(loaded) == expected_count

    for tracker in loaded:
        assert marker_trackerstore.exists(tracker.sender_id)
コード例 #5
0
def test_raise_connection_exception_redis_tracker_store_creation(
    domain: Domain, monkeypatch: MonkeyPatch, endpoints_path: Text
):
    """A `ConnectionError` while constructing `RedisTrackerStore` surfaces as `ConnectionException`."""
    endpoint_config = read_endpoint_config(endpoints_path, "tracker_store")

    # Replace the Redis store class with a mock whose construction always fails.
    failing_store_cls = Mock(side_effect=ConnectionError())
    monkeypatch.setattr(rasa.core.tracker_store, "RedisTrackerStore", failing_store_cls)

    with pytest.raises(ConnectionException):
        TrackerStore.create(endpoint_config, domain)
コード例 #6
0
ファイル: test_trackers.py プロジェクト: attgua/Geco
    def __init__(self, _domain: Domain) -> None:
        """Initialise the mock store on a fakeredis client with the tracker key prefix."""
        fake_client = fakeredis.FakeStrictRedis()
        self.red = fake_client
        self.record_exp = None

        # fakeredis lacks `health_check_interval` (introduced in redis 3.3.0),
        # so set it on the connection class by hand.
        fake_client.connection_pool.connection_class.health_check_interval = 0

        # RedisTrackerStore defines this prefix itself; the mock has to add it.
        self.prefix = "tracker:"

        TrackerStore.__init__(self, _domain)
コード例 #7
0
def get_or_create_tracker_store(store: TrackerStore) -> None:
    """Check that a slot set on a tracker survives `store.save` and a reload.

    Args:
        store: The tracker store under test.
    """
    key, value = "location", "Easter Island"

    first = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    first.update(SlotSet(key, value))
    assert first.get_slot(key) == value

    store.save(first)

    # Fetch the same sender again; the slot must come back from the store.
    reloaded = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
    assert reloaded.get_slot(key) == value
コード例 #8
0
ファイル: test_tracker_stores.py プロジェクト: spawn08/rasa
def test_mongo_tracker_store_raise_exception(domain: Domain, monkeypatch: MonkeyPatch):
    """An `OperationFailure` from `MongoTrackerStore` becomes a `ConnectionException`.

    The original failure message must be preserved in the raised exception.
    """
    auth_error = "not authorized on logs to execute command."
    failing_store_cls = Mock(side_effect=OperationFailure(auth_error))
    monkeypatch.setattr(rasa.core.tracker_store, "MongoTrackerStore", failing_store_cls)

    with pytest.raises(ConnectionException) as excinfo:
        TrackerStore.create(
            EndpointConfig(username="******", password="******", type="mongod"),
            domain,
        )

    assert auth_error in str(excinfo.value)
コード例 #9
0
ファイル: bots.py プロジェクト: wireless911/chatbot-saas
    def __init__(
        self,
        path_context: PathContext,
        agent: Agent = None,
    ):
        """Agent manager.

        Args:
            path_context: Source of the file paths used by this manager; its
                `endpoints_file_path` is read here.
            agent: Optional pre-built agent. Falsy values are stored as `None`.
        """
        self.path_context = path_context
        endpoints_path = self.path_context.endpoints_file_path

        self.agent = agent or None

        # Event broker and tracker store: use the endpoints file when it exists,
        # otherwise fall back to default settings.
        if os.path.exists(endpoints_path):
            self._load_endpoints(endpoints=endpoints_path)
        else:
            # TODO: the file-based event broker is not well suited for production.
            event_broker = FileEventBroker()
            # NOTE(review): this builds the base `TrackerStore` class from
            # REDIS_SETTING rather than a Redis-backed store. Also, `generator`,
            # `action_endpoint` and `lock_store` appear to be set only in
            # `_load_endpoints`, so they are absent on this path. Confirm both
            # are intended.
            from config.settings import REDIS_SETTING
            self.tracker_store = TrackerStore(**REDIS_SETTING,
                                              event_broker=event_broker)
コード例 #10
0
async def test_events_schema(
    monkeypatch: MonkeyPatch, default_agent: Agent, config_path: Text
):
    """Emit known backend telemetry events and validate each against its schema.

    `telemetry.print_telemetry_event` is replaced by a mock, so every reported
    event can be read back from the mock's call arguments. Each event is then
    validated against the matching entry of the `events` section in
    `TELEMETRY_EVENTS_JSON`.
    """
    # this allows us to patch the printing part used in debug mode to collect the
    # reported events
    monkeypatch.setenv("RASA_TELEMETRY_DEBUG", "true")
    monkeypatch.setenv("RASA_TELEMETRY_ENABLED", "true")

    mock = Mock()
    monkeypatch.setattr(telemetry, "print_telemetry_event", mock)

    # Schemas are looked up below by event name (`event["event"]`).
    with open(TELEMETRY_EVENTS_JSON) as f:
        schemas = json.load(f)["events"]
    # Snapshot the tasks that already exist, so that only tasks spawned by the
    # telemetry calls below are awaited at the end.
    initial = asyncio.all_tasks()
    # Generate all known backend telemetry events, and then use events.json to
    # validate their schema.
    training_data = TrainingDataImporter.load_from_config(config_path)

    with telemetry.track_model_training(training_data, "rasa"):
        await asyncio.sleep(1)

    telemetry.track_telemetry_disabled()

    telemetry.track_data_split(0.5, "nlu")

    telemetry.track_validate_files(True)

    telemetry.track_data_convert("yaml", "nlu")

    telemetry.track_tracker_export(5, TrackerStore(domain=None), EventBroker())

    telemetry.track_interactive_learning_start(True, False)

    telemetry.track_server_start([CmdlineInput()], None, None, 42, True)

    telemetry.track_project_init("tests/")

    telemetry.track_shell_started("nlu")

    telemetry.track_rasa_x_local()

    telemetry.track_visualization()

    telemetry.track_core_model_test(5, True, default_agent)

    telemetry.track_nlu_model_test(TrainingData())

    # Wait for the tasks started by the telemetry calls before inspecting the mock.
    pending = asyncio.all_tasks() - initial
    await asyncio.gather(*pending)

    # NOTE(review): 15 is one more than the 14 `track_*` calls above, presumably
    # because the `track_model_training` context manager reports more than one
    # event. Update this count whenever a tracking call is added or removed.
    assert mock.call_count == 15

    for args, _ in mock.call_args_list:
        event = args[0]
        # `metrics_id` automatically gets added to all event but is
        # not part of the schema so we need to remove it before validation
        del event["properties"]["metrics_id"]
        jsonschema.validate(
            instance=event["properties"], schema=schemas[event["event"]]
        )
コード例 #11
0
async def load_agent_on_start(
    model_path: Text,
    endpoints: AvailableEndpoints,
    remote_storage: Optional[Text],
    app: Sanic,
    loop: Text,
):
    """Build the agent served by `app` and store it on `app.agent`.

    Meant to be scheduled when the server starts, which is why it receives the
    `app` and `loop` arguments. If the NLU interpreter cannot be created, the
    agent is built without one. If `load_agent` returns nothing, a model-less
    agent is created instead.

    Returns:
        The agent assigned to `app.agent`.
    """
    from rasa.core import broker

    try:
        model_context = get_model(model_path)
        if not model_context:
            raise RuntimeError(f"No model found at '{model_path}'.")
        with model_context as unpacked_model:
            _core_dir, nlu_dir = get_model_subdirectories(unpacked_model)
            interpreter = NaturalLanguageInterpreter.create(nlu_dir, endpoints.nlu)
    except Exception:
        logger.debug(f"Could not load interpreter from '{model_path}'.")
        interpreter = None

    event_broker = broker.from_endpoint_config(endpoints.event_broker)
    tracker_store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, event_broker
    )
    model_server = endpoints.model if endpoints and endpoints.model else None

    app.agent = await load_agent(
        model_path,
        model_server=model_server,
        remote_storage=remote_storage,
        interpreter=interpreter,
        generator=endpoints.nlg,
        tracker_store=tracker_store,
        action_endpoint=endpoints.action,
    )
    if app.agent:
        return app.agent

    logger.warning(
        "Agent could not be loaded with the provided configuration. "
        "Load default agent without any model."
    )
    app.agent = Agent(
        interpreter=interpreter,
        generator=endpoints.nlg,
        tracker_store=tracker_store,
        action_endpoint=endpoints.action,
        model_server=model_server,
        remote_storage=remote_storage,
    )
    return app.agent
コード例 #12
0
ファイル: evaluate.py プロジェクト: spawn08/rasa
def _create_tracker_loader(
    endpoint_config: Text,
    strategy: Text,
    domain: Domain,
    count: Optional[int],
    seed: Optional[int],
) -> MarkerTrackerLoader:
    """Build a `MarkerTrackerLoader` on top of the tracker store from an endpoints file.

    Args:
        endpoint_config: Path to the endpoints file whose `tracker_store`
                         section defines the store to connect to.
        strategy: Tracker selection strategy passed to the loader.
        domain: Domain used when creating the tracker store.
        count: (Optional) Number of trackers to select; not used by the
               'all' strategy.
        seed: (Optional) Random seed, relevant for the 'sample_n' strategy.

    Returns:
        A loader using `strategy` against the configured tracker store.
    """
    store_config = AvailableEndpoints.read_endpoints(endpoint_config).tracker_store
    store = TrackerStore.create(store_config, domain=domain)
    return MarkerTrackerLoader(store, strategy, count, seed)
コード例 #13
0
ファイル: run.py プロジェクト: ravishankr/rasa
async def load_agent_on_start(
    model_path: Text,
    endpoints: AvailableEndpoints,
    remote_storage: Optional[Text],
    app: Sanic,
    loop: AbstractEventLoop,
):
    """Create the agent served by `app` and store it on `app.agent`.

    Intended to be scheduled when the server starts, hence the `app` and `loop`
    arguments. If the model cannot be loaded, a warning is raised and an agent
    without a model is created instead.

    Returns:
        The agent assigned to `app.agent`.
    """
    # noinspection PyBroadException
    try:
        with model.get_model(model_path) as unpacked_model:
            _core_dir, nlu_dir = model.get_model_subdirectories(unpacked_model)
            interpreter = NaturalLanguageInterpreter.create(endpoints.nlu or nlu_dir)
    except Exception:
        logger.debug(f"Could not load interpreter from '{model_path}'.")
        interpreter = None

    event_broker = EventBroker.create(endpoints.event_broker)
    tracker_store = TrackerStore.create(endpoints.tracker_store, event_broker=event_broker)
    lock_store = LockStore.create(endpoints.lock_store)

    model_server = endpoints.model if endpoints and endpoints.model else None

    loaded_agent = None
    try:
        loaded_agent = await agent.load_agent(
            model_path,
            model_server=model_server,
            remote_storage=remote_storage,
            interpreter=interpreter,
            generator=endpoints.nlg,
            tracker_store=tracker_store,
            lock_store=lock_store,
            action_endpoint=endpoints.action,
        )
    except Exception as e:
        rasa.shared.utils.io.raise_warning(
            f"The model at '{model_path}' could not be loaded. Error: {e}"
        )
    app.agent = loaded_agent

    if not app.agent:
        rasa.shared.utils.io.raise_warning(
            "Agent could not be loaded with the provided configuration. "
            "Load default agent without any model."
        )
        app.agent = Agent(
            interpreter=interpreter,
            generator=endpoints.nlg,
            tracker_store=tracker_store,
            action_endpoint=endpoints.action,
            model_server=model_server,
            remote_storage=remote_storage,
        )

    logger.info("Rasa server is up and running.")
    return app.agent
コード例 #14
0
ファイル: agent.py プロジェクト: ducminh-phan/rasa
def create_agent(model: Text, endpoints: Text = None) -> "Agent":
    """Load an `Agent` from a stored model, wiring in the configured endpoints.

    Args:
        model: Path to the stored model.
        endpoints: Path to the endpoint configuration file to use.

    Returns:
        The agent returned by `Agent.load`.
    """
    from rasa.core.tracker_store import TrackerStore
    from rasa.core.utils import AvailableEndpoints
    from rasa.core.brokers.broker import EventBroker
    import rasa.utils.common

    configured = AvailableEndpoints.read_endpoints(endpoints)

    # `EventBroker.create` is awaitable, so drive it to completion here.
    event_broker = rasa.utils.common.run_in_loop(
        EventBroker.create(configured.event_broker)
    )
    tracker_store = TrackerStore.create(configured.tracker_store, event_broker=event_broker)
    lock_store = LockStore.create(configured.lock_store)

    return Agent.load(
        model,
        generator=configured.nlg,
        tracker_store=tracker_store,
        lock_store=lock_store,
        action_endpoint=configured.action,
    )
コード例 #15
0
ファイル: test_tracker_stores.py プロジェクト: vivihuang/rasa
def test_tracker_store_from_string(default_domain: Domain):
    """`find_tracker_store` returns an `ExampleTrackerStore` for `custom_tracker_endpoints.yml`."""
    store_config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )

    found = TrackerStore.find_tracker_store(default_domain, store_config)
    assert isinstance(found, ExampleTrackerStore)
コード例 #16
0
def test_exception_tracker_store_from_endpoint_config(
        default_domain: Domain, monkeypatch: MonkeyPatch):
    """Check that tracker store creation propagates construction errors.

    `RedisTrackerStore` is patched to raise while being instantiated. In that
    case `TrackerStore.create` must raise an exception carrying the original
    message, instead of silently falling back to another tracker store.
    """

    store = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
    mock = Mock(side_effect=Exception("test exception"))
    monkeypatch.setattr(rasa.core.tracker_store, "RedisTrackerStore", mock)

    with pytest.raises(Exception) as e:
        TrackerStore.create(store, default_domain)

    assert "test exception" in str(e.value)
コード例 #17
0
async def load_agent_on_start(core_model, endpoints, nlu_model, app, loop):
    """Build the agent served by `app` and store it on `app.agent`.

    Meant to be scheduled when the server starts, hence the `app` and `loop`
    arguments. When `endpoints.model` is configured, an agent is created and
    its model is fetched from that model server. Otherwise the agent is loaded
    from `core_model`.

    Returns:
        The agent assigned to `app.agent`.
    """
    from rasa.core import broker
    from rasa.core.agent import Agent

    interpreter = NaturalLanguageInterpreter.create(nlu_model, endpoints.nlu)
    event_broker = broker.from_endpoint_config(endpoints.event_broker)
    tracker_store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, event_broker
    )

    if not (endpoints and endpoints.model):
        app.agent = Agent.load(
            core_model,
            interpreter=interpreter,
            generator=endpoints.nlg,
            tracker_store=tracker_store,
            action_endpoint=endpoints.action,
        )
        return app.agent

    from rasa.core import agent

    app.agent = Agent(
        interpreter=interpreter,
        generator=endpoints.nlg,
        tracker_store=tracker_store,
        action_endpoint=endpoints.action,
    )
    await agent.load_from_server(app.agent, model_server=endpoints.model)
    return app.agent
コード例 #18
0
def create_agent(model: Text, endpoints: Text = None) -> 'Agent':
    """Create an agent from an unpacked model directory.

    Args:
        model: Path to the unpacked model; its core and NLU subdirectories are
            located with `get_model_subdirectories`.
        endpoints: Path to the endpoint configuration file.

    Returns:
        The `Agent` loaded from the core model, using the NLU interpreter
        when an NLU model directory exists.
    """
    from rasa.core.interpreter import RasaNLUInterpreter
    from rasa.core.tracker_store import TrackerStore
    from rasa.core import broker
    from rasa.core.utils import AvailableEndpoints

    core_path, nlu_path = get_model_subdirectories(model)
    _endpoints = AvailableEndpoints.read_endpoints(endpoints)

    if os.path.exists(nlu_path):
        _interpreter = RasaNLUInterpreter(model_directory=nlu_path)
    else:
        _interpreter = None
        logging.info("No NLU model found. Running without NLU.")

    _broker = broker.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(None,
                                                     _endpoints.tracker_store,
                                                     _broker)

    # Hand the interpreter to the agent; otherwise the NLU model loaded above
    # is built and then discarded.
    return Agent.load(core_path,
                      interpreter=_interpreter,
                      generator=_endpoints.nlg,
                      tracker_store=_tracker_store,
                      action_endpoint=_endpoints.action)
コード例 #19
0
async def load_agent_on_start(
    model_path: Text,
    endpoints: AvailableEndpoints,
    remote_storage: Optional[Text],
    app: Sanic,
    loop: Text,
):
    """Create the agent served by `app` and store it on `app.agent`.

    Intended to be scheduled on server start, hence the `app` and `loop`
    arguments. One NLU interpreter is created per language key of the model's
    NLU subdirectories. If that fails, no interpreters are used.

    Returns:
        The agent assigned to `app.agent`.
    """
    import rasa.core.brokers.utils as broker_utils

    # bf mod: build a {language: interpreter} mapping
    # noinspection PyBroadException
    try:
        with model.get_model(model_path) as unpacked_model:
            _core_dir, nlu_models = model.get_model_subdirectories(unpacked_model)
            interpreters = {
                lang: NaturalLanguageInterpreter.create(nlu_model_path, endpoints.nlu)
                for lang, nlu_model_path in nlu_models.items()
            }
    except Exception:
        logger.debug(f"Could not load interpreter from '{model_path}'.")
        interpreters = {}
    # /bf mod

    event_broker = broker_utils.from_endpoint_config(endpoints.event_broker)
    tracker_store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, event_broker
    )
    lock_store = LockStore.find_lock_store(endpoints.lock_store)

    model_server = endpoints.model if endpoints and endpoints.model else None

    app.agent = await agent.load_agent(
        model_path,
        model_server=model_server,
        remote_storage=remote_storage,
        interpreters=interpreters,
        generator=endpoints.nlg,
        tracker_store=tracker_store,
        lock_store=lock_store,
        action_endpoint=endpoints.action,
    )
    if app.agent:
        return app.agent

    warnings.warn(
        "Agent could not be loaded with the provided configuration. "
        "Load default agent without any model."
    )
    app.agent = Agent(
        interpreters=interpreters,
        generator=endpoints.nlg,
        tracker_store=tracker_store,
        action_endpoint=endpoints.action,
        model_server=model_server,
        remote_storage=remote_storage,
    )
    return app.agent
コード例 #20
0
ファイル: test_tracker_stores.py プロジェクト: vivihuang/rasa
def test_tracker_store_from_invalid_string(default_domain: Domain):
    """An unresolvable type string ("any string") falls back to `InMemoryTrackerStore`."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    endpoint_config = read_endpoint_config(config_file, "tracker_store")
    endpoint_config.type = "any string"

    found = TrackerStore.find_tracker_store(default_domain, endpoint_config)
    assert isinstance(found, InMemoryTrackerStore)
コード例 #21
0
def test_create_non_async_tracker_store(domain: Domain):
    """Creating `NonAsyncTrackerStore` warns and yields an `AwaitableTrackerStore` around it."""
    config = EndpointConfig(type="tests.core.test_tracker_stores.NonAsyncTrackerStore")

    with pytest.warns(FutureWarning):
        created = TrackerStore.create(config)

    assert isinstance(created, AwaitableTrackerStore)
    # The wrapped instance is the configured class itself.
    assert isinstance(created._tracker_store, NonAsyncTrackerStore)
コード例 #22
0
ファイル: test_tracker_stores.py プロジェクト: vivihuang/rasa
def test_tracker_store_from_invalid_module(default_domain: Domain):
    """A type pointing at a missing module falls back to `InMemoryTrackerStore`."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    endpoint_config = read_endpoint_config(config_file, "tracker_store")
    endpoint_config.type = "a.module.which.cannot.be.found"

    found = TrackerStore.find_tracker_store(default_domain, endpoint_config)
    assert isinstance(found, InMemoryTrackerStore)
コード例 #23
0
ファイル: test_tracker_stores.py プロジェクト: wavymazy/rasa
def test_tracker_store_from_invalid_module(domain: Domain):
    """A missing module in the store type emits a `UserWarning` and falls back to in-memory."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    endpoint_config = read_endpoint_config(config_file, "tracker_store")
    endpoint_config.type = "a.module.which.cannot.be.found"

    with pytest.warns(UserWarning):
        created = TrackerStore.create(endpoint_config, domain)

    assert isinstance(created, InMemoryTrackerStore)
コード例 #24
0
def test_tracker_store_from_invalid_string(default_domain: Domain):
    """An unresolvable type string emits a `UserWarning` and falls back to in-memory."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    endpoint_config = read_endpoint_config(config_file, "tracker_store")
    endpoint_config.type = "any string"

    with pytest.warns(UserWarning):
        created = TrackerStore.create(endpoint_config, default_domain)

    assert isinstance(created, InMemoryTrackerStore)
コード例 #25
0
def test_tracker_store_deprecated_url_argument_from_string(default_domain: Domain):
    """Creating `URLExampleTrackerStore` succeeds but emits a `DeprecationWarning`."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    endpoint_config = read_endpoint_config(config_file, "tracker_store")
    endpoint_config.type = "tests.core.test_tracker_stores.URLExampleTrackerStore"

    with pytest.warns(DeprecationWarning):
        created = TrackerStore.create(endpoint_config, default_domain)

    assert isinstance(created, URLExampleTrackerStore)
コード例 #26
0
ファイル: bots.py プロジェクト: wireless911/chatbot-saas
 def _load_endpoints(self, endpoints: Optional[Text] = None):
     """Load the endpoints file and configure this manager's collaborators from it.

     Sets `tracker_store` (with an event broker), `generator`,
     `action_endpoint` and `lock_store`.
     """
     available = AvailableEndpoints.read_endpoints(endpoints)
     event_broker = EventBroker.create(available.event_broker)
     self.tracker_store = TrackerStore.create(
         available.tracker_store, event_broker=event_broker
     )
     self.generator = available.nlg
     self.action_endpoint = available.action
     self.lock_store = LockStore.create(available.lock_store)
コード例 #27
0
def test_load_first_n(marker_trackerstore: TrackerStore):
    """The 'first_n' strategy with n=3 yields exactly three trackers from the store."""
    wanted = 3
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_FIRST_N, wanted)
    trackers = list(loader.load())

    assert len(trackers) == wanted

    for tracker in trackers:
        assert marker_trackerstore.exists(tracker.sender_id)
コード例 #28
0
def test_find_tracker_store(default_domain: Domain, monkeypatch: MonkeyPatch):
    """`find_tracker_store` falls back to an in-memory store when creation fails.

    `RedisTrackerStore` is patched to raise on construction, so building the
    store configured in `DEFAULT_ENDPOINTS_FILE` (presumably Redis — confirm)
    fails and the type of the fallback store is checked.
    """
    store = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
    mock = Mock(side_effect=Exception("ignore this"))
    monkeypatch.setattr(rasa.core.tracker_store, "RedisTrackerStore", mock)

    # Use the fixture's domain: `domain` is not a parameter of this test and
    # would raise NameError or pick up an unrelated module-level name.
    assert isinstance(
        InMemoryTrackerStore(default_domain),
        type(TrackerStore.find_tracker_store(default_domain, store)),
    )
コード例 #29
0
def test_tracker_store_with_host_argument_from_string(default_domain: Domain):
    """Creating `HostExampleTrackerStore` from the endpoint config emits no warnings."""
    import warnings

    endpoints_path = "data/test_endpoints/custom_tracker_endpoints.yml"
    store_config = read_endpoint_config(endpoints_path, "tracker_store")
    store_config.type = "tests.core.test_tracker_stores.HostExampleTrackerStore"

    # `pytest.warns(None)` is deprecated in pytest 7 and removed in pytest 8.
    # Record every warning explicitly instead ("always" so none are deduplicated).
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        tracker_store = TrackerStore.create(store_config, default_domain)

    assert len(record) == 0

    assert isinstance(tracker_store, HostExampleTrackerStore)
コード例 #30
0
def test_create_tracker_store_from_endpoint_config(default_domain: Domain):
    """A manually built `RedisTrackerStore` is an instance of the type `create` returns.

    `TrackerStore.create` is applied to the default endpoints file.
    """
    endpoint_config = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
    manual_store = RedisTrackerStore(
        domain=default_domain,
        host="localhost",
        port=6379,
        db=0,
        password="******",
        record_exp=3000,
    )

    created_type = type(TrackerStore.create(endpoint_config, default_domain))
    assert isinstance(manual_store, created_type)