コード例 #1
0
    def test_skip_predictions_to_prevent_loop(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        caplog: LogCaptureFixture,
        tracker_events: List[Event],
        should_skip: bool,
        tmp_path: Path,
    ):
        """Check that the policy skips its predictions to prevent a loop.

        The tracker is replayed from ``tracker_events``. The skip message must
        be logged exactly when ``should_skip`` is true. In that case the
        returned probabilities must equal the policy's default predictions.
        """
        policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        tracker = DialogueStateTracker(
            sender_id="init", slots=default_domain.slots
        )
        tracker.update_with_events(tracker_events, default_domain)

        with caplog.at_level(logging.DEBUG):
            # No precomputations are supplied (third positional argument).
            prediction = policy.predict_action_probabilities(
                tracker, default_domain, None
            )

        skip_logged = (
            "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text
        )
        assert skip_logged == should_skip

        if not should_skip:
            return
        assert prediction.probabilities == policy._default_predictions(
            default_domain
        )
コード例 #2
0
ファイル: test_utils.py プロジェクト: souvikg10/rasa_nlu
def test_validate_with_none_if_default_is_valid(caplog: LogCaptureFixture):
    """A ``None`` path falls back to a valid default without logging a warning.

    The default directory comes from ``tempfile.TemporaryDirectory``, so it is
    removed again when the test finishes rather than left on disk.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        with caplog.at_level(logging.WARNING, rasa.cli.utils.logger.name):
            assert get_validated_path(None, "out", tempdir) == tempdir

    assert caplog.records == []
コード例 #3
0
    def test_skip_predictions_if_new_intent(
        self,
        trained_policy: UnexpecTEDIntentPolicy,
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        default_domain: Domain,
        caplog: LogCaptureFixture,
        tracker_events: List[Event],
    ):
        """Check that predictions are skipped when a new intent appears.

        The skip message must be logged, and the policy must return its
        default predictions for ``default_domain``.
        """
        policy = self.persist_and_load_policy(
            trained_policy, model_storage, resource, execution_context
        )
        tracker = DialogueStateTracker(
            sender_id="init", slots=default_domain.slots
        )
        tracker.update_with_events(tracker_events, default_domain)

        with caplog.at_level(logging.DEBUG):
            result = policy.predict_action_probabilities(
                tracker, default_domain, precomputations=None
            )

        expected_log = "Skipping predictions for UnexpecTEDIntentPolicy"
        assert expected_log in caplog.text
        assert result.probabilities == policy._default_predictions(
            default_domain
        )
コード例 #4
0
    def test_retry_and_succeed(
        self, database_type: DatabaseType, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture
    ) -> None:
        """Connecting succeeds on the third attempt after two failed ones.

        ``pyodbc.connect`` is patched to raise ``OperationalError`` on its first
        two calls and to delegate to the real driver afterwards. The retry sleep
        is shortened to 0.01 seconds, and the complete sequence of log messages
        produced by the retry loop is checked.
        """
        real_connect = pyodbc.connect
        calls = {"failed": 0}

        def flaky_connect(connection_string: str, timeout: float) -> pyodbc.Connection:
            # Fail the first two calls; every later call reaches the real driver.
            if calls["failed"] >= 2:
                return real_connect(connection_string, timeout=timeout)
            calls["failed"] += 1
            raise pyodbc.OperationalError(f"Test Fail {calls['failed']}")

        monkeypatch.setattr(master_config, "db_connection_retry_sleep_seconds", 0.01)
        monkeypatch.setattr(pyodbc, "connect", flaky_connect)

        with caplog.at_level(logging.DEBUG):
            database = Database(database_type)
            # Run something that explicitly connects to the database.
            database.get_existing_table_names()

        assert caplog.messages == [
            f"Trying to connect to {database}, attempt #1",
            f"Unsuccessful attempt #1 to connect to {database}: Test Fail 1",
            "Waiting 0.01 seconds before next connection attempt",
            f"Trying to connect to {database}, attempt #2",
            f"Unsuccessful attempt #2 to connect to {database}: Test Fail 2",
            "Waiting 0.01 seconds before next connection attempt",
            f"Trying to connect to {database}, attempt #3",
            f"Successfully connected to {database}",
        ]
コード例 #5
0
ファイル: test_channels.py プロジェクト: ChenHuaYou/rasa
async def test_socketio_channel_jwt_authentication_invalid_key(
    caplog: LogCaptureFixture,
):
    """A token signed with the wrong key is rejected, and an error is logged."""
    from rasa.core.channels.socketio import SocketIOInput

    algorithm = "HS256"
    expected_key = "random_key123"
    wrong_key = "my_invalid_key"
    # Sign the payload with a key the channel is not configured with.
    bad_token = jwt.encode({"payload": "value"}, wrong_key, algorithm=algorithm)

    channel = SocketIOInput(
        user_message_evt="user_uttered",  # event name for messages sent by the user
        bot_message_evt="bot_uttered",  # event name for messages sent by the bot
        namespace=None,  # socket.io namespace to use for the messages
        jwt_key=expected_key,  # key used by the JWT methods
        jwt_method=algorithm,  # signature method of the JWT payload
    )

    assert channel.jwt_key == expected_key
    assert channel.jwt_algorithm == algorithm

    with caplog.at_level(logging.ERROR):
        rasa.core.channels.channel.decode_bearer_token(
            bad_token, channel.jwt_key, channel.jwt_algorithm
        )

    assert any("JWT public key invalid." in msg for msg in caplog.messages)
コード例 #6
0
    def test_e2e_gives_experimental_warning(
        self,
        monkeypatch: MonkeyPatch,
        trained_e2e_model: Text,
        domain_path: Text,
        stack_config_path: Text,
        e2e_stories_path: Text,
        nlu_data_path: Text,
        caplog: LogCaptureFixture,
    ):
        """Training with end-to-end stories logs an "experimental" warning.

        NLU and Core training are patched through the ``mock_*_training``
        helpers before ``train`` is called with the end-to-end stories.
        """
        mock_nlu_training(monkeypatch)
        mock_core_training(monkeypatch)

        with caplog.at_level(logging.WARNING):
            train(
                domain_path,
                stack_config_path,
                [e2e_stories_path, nlu_data_path],
                output=new_model_path_in_same_dir(trained_e2e_model),
            )

        # A generator lets ``any`` stop at the first match without first
        # building a throwaway list.
        assert any(
            "The end-to-end training is currently experimental" in record.message
            for record in caplog.records
        )
コード例 #7
0
    def test_no_connection_when_disabled(
        self, database_type: DatabaseType, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture
    ) -> None:
        """With ``db_connection_attempts`` set to 0, connecting is skipped.

        Constructing ``Database`` must log exactly one message that explains
        why the connection was skipped.
        """
        monkeypatch.setattr(master_config, "db_connection_attempts", 0)

        with caplog.at_level(logging.INFO):
            database = Database(database_type)
            # The check happens while the capture level is still active.
            assert caplog.messages == [
                f"Skipping {database} connection due to db_connection_attempts==0"
            ]
コード例 #8
0
 def test_driver_can_not_get_status(
         self, repo: KrakenRepository, mocker: MockerFixture, caplog: LogCaptureFixture
 ) -> None:
     """``get_status`` returns None, logs the error and cleans up on ``OSError``.

     ``caplog.at_level`` only changes the capture level when it is used as a
     context manager (a bare call has no effect), so it wraps the call under
     test.
     """
     # arrange
     error_message = 'USB Communication Error'
     mocker.patch.object(
         repo, '_driver', spec=KrakenX3
     )
     mocker.patch.object(
         repo._driver, 'get_status', side_effect=OSError(error_message)
     )
     mocker.patch.object(repo, 'cleanup')
     # act
     with caplog.at_level(logging.ERROR):
         status = repo.get_status()
     # assert
     assert status is None
     assert f'Error getting the status: {error_message}' in caplog.text
     repo.cleanup.assert_called_once()
コード例 #9
0
ファイル: test_utils.py プロジェクト: zeroesones/rasa
def test_validate_with_invalid_directory_if_default_is_valid(
        caplog: LogCaptureFixture):
    """A non-existent path falls back to a valid default and logs a warning.

    The default directory comes from ``tempfile.TemporaryDirectory``, so it is
    removed again when the test finishes rather than left on disk.
    """
    invalid_directory = "gcfhvjkb"

    with tempfile.TemporaryDirectory() as tempdir:
        with caplog.at_level(logging.WARNING, rasa.cli.utils.logger.name):
            assert get_validated_path(invalid_directory, "out", tempdir) == tempdir

    assert f"'{invalid_directory}' does not exist" in caplog.text
コード例 #10
0
 def test_status_unknown_driver_type(
         self, repo: KrakenRepository, mocker: MockerFixture, caplog: LogCaptureFixture
 ) -> None:
     """``get_status`` returns None and logs a message for an unknown driver type.

     ``caplog.at_level`` only changes the capture level when it is used as a
     context manager (a bare call has no effect), so it wraps the call under
     test.
     """
     # arrange
     mocker.patch.object(
         repo, '_driver',
         # will likely never be supported by gkraken:
         spec=CorsairHidPsu
     )
     mocker.patch.object(
         repo._driver, 'get_status', return_value=[
             ('Fan Speed', 238, 'rpm')
         ]
     )
     # act
     with caplog.at_level(logging.ERROR):
         status = repo.get_status()
     # assert
     assert status is None
     assert 'Driver Instance is not recognized' in caplog.text
コード例 #11
0
def test_load_without_training(
    create_or_load_mitie_extractor: Callable[[Dict[Text, Any]],
                                             MitieEntityExtractor],
    caplog: LogCaptureFixture,
):
    """Loading an extractor that was never trained logs a load failure."""
    expected = "Failed to load MitieEntityExtractor from model storage."

    with caplog.at_level(logging.DEBUG):
        create_or_load_mitie_extractor({}, load=True)

    matches = [message for message in caplog.messages if expected in message]
    assert matches
コード例 #12
0
ファイル: test_broker.py プロジェクト: yang198876/rasa
def test_no_pika_logs_if_no_debug_mode(caplog: LogCaptureFixture):
    """Outside debug mode, a failed pika connection produces no log records."""
    from rasa.core.brokers import pika

    # The connection attempt is expected to raise (presumably no broker is
    # reachable on localhost in the test environment).
    with caplog.at_level(logging.INFO), pytest.raises(Exception):
        pika.initialise_pika_connection(
            "localhost", "user", "password", connection_attempts=1
        )

    assert not caplog.records
コード例 #13
0
def test_store_forecast_with_disabled_database(data_output: DataOutput, caplog: LogCaptureFixture) -> None:
    """Storing a forecast is skipped when the internal database is disabled.

    The very same model run object is returned, and a message explains why
    nothing was stored.
    """
    data_output._internal_database._is_disabled = True

    empty_forecast = pd.DataFrame()
    run = ForecastModelRun()

    with caplog.at_level(logging.DEBUG):
        result = data_output._store_forecast_in_internal_database(
            forecast=empty_forecast, model_run=run
        )

    assert result is run
    expected_message = (
        "Skip storing forecast in forecast_data table because of disabled internal database"
    )
    assert expected_message in caplog.messages
コード例 #14
0
def test_policy_loading_load_returns_none(tmp_path: Path,
                                          caplog: LogCaptureFixture):
    """A policy whose ``load`` returns None is dropped with a warning.

    The last captured message is read through ``caplog.messages``. That checks
    the rendered text (with any ``%``-style arguments applied) and leaves the
    captured records untouched.
    """
    original_policy_ensemble = PolicyEnsemble([LoadReturnsNonePolicy()])
    original_policy_ensemble.train([], None, RegexInterpreter())
    original_policy_ensemble.persist(str(tmp_path))

    with caplog.at_level(logging.WARNING):
        ensemble = PolicyEnsemble.load(str(tmp_path))
        # Fail with a clear message instead of an IndexError if nothing was logged.
        assert caplog.messages, "expected a warning about the failed policy load"
        assert (caplog.messages[-1] ==
                "Failed to load policy tests.core.test_ensemble."
                "LoadReturnsNonePolicy: load returned None")
        assert len(ensemble.policies) == 0
コード例 #15
0
def test_pika_logs_in_debug_mode(caplog: LogCaptureFixture,
                                 monkeypatch: MonkeyPatch):
    """In debug mode, a failed pika connection leaves log records behind."""
    from rasa.core.brokers import pika

    with caplog.at_level(logging.DEBUG), pytest.raises(Exception):
        pika.initialise_pika_connection(
            "localhost", "user", "password", connection_attempts=1
        )

    assert caplog.records, "expected log records in debug mode"
コード例 #16
0
def test_graph_trainer_train_logging_with_cached_components(
    tmp_path: Path,
    temp_cache: TrainingCache,
    train_with_schema: Callable,
    caplog: LogCaptureFixture,
):
    """A second training run restores the cacheable component from the cache.

    The schema is trained once to fill ``temp_cache``. On the second run the
    training-hooks logger must report ``SubtractByX`` being trained again and
    ``CacheableComponent`` being restored from the cache.
    """
    (tmp_path / "input_file.txt").write_text("3")

    input_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )
    subtract_node = SchemaNode(
        needs={"i": "input"},
        uses=SubtractByX,
        fn="subtract_x",
        constructor_name="create",
        config={"x": 1},
        is_target=True,
        is_input=False,
    )
    cacheable_node = SchemaNode(
        needs={"suffix": "input"},
        uses=CacheableComponent,
        fn="run",
        constructor_name="create",
        config={},
        is_target=True,
        is_input=False,
    )
    schema = GraphSchema(
        {
            "input": input_node,
            "subtract": subtract_node,
            "cache_able_node": cacheable_node,
        }
    )

    # First run populates the cache.
    train_with_schema(schema, temp_cache)

    # Second run: the cacheable component should now come from the cache.
    with caplog.at_level(logging.INFO, logger="rasa.engine.training.hooks"):
        train_with_schema(schema, temp_cache)

        assert set(caplog.messages) == {
            "Starting to train component 'SubtractByX'.",
            "Finished training component 'SubtractByX'.",
            "Restored component 'CacheableComponent' from cache.",
        }
コード例 #17
0
    def test_register_not_sampled(self, space: Space,
                                  caplog: LogCaptureFixture):
        """Check that a trial which was never sampled cannot be registered.

        Observing it must produce exactly one record on the hyperband logger,
        and that record must mention "Ignoring trial".
        """
        algo = Hyperband(space)
        trial = create_trial_for_hb((2, 50))  # (fidelity, value)

        with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
            algo.observe([trial])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring trial" in records[0].msg
コード例 #18
0
async def test_no_pika_logs_if_no_debug_mode(caplog: LogCaptureFixture):
    """Outside debug mode, only Rasa's pika broker logger and asyncio may log."""
    broker = PikaEventBroker(
        "host",
        "username",
        "password",
        retry_delay_in_seconds=1,
        connection_attempts=1,
    )

    with caplog.at_level(logging.INFO), pytest.raises(Exception):
        await broker.connect()

    # Only Rasa Open Source (and asyncio) may log; no logs come from the
    # pika library itself.
    allowed_loggers = ["rasa.core.brokers.pika", "asyncio"]
    assert all(record.name in allowed_loggers for record in caplog.records)
コード例 #19
0
def test_sql_tracker_store_with_login_db_race_condition(
    postgres_login_db_connection: sa.engine.Connection,
    caplog: LogCaptureFixture,
    monkeypatch: MonkeyPatch,
):
    """A tracker database created concurrently during store setup is tolerated.

    ``Connection.execute`` is patched so that the query executed with
    ``database_name`` as keyword argument (presumably the "does the database
    exist" check) creates the database itself and reports zero rows. The store
    is then expected to log that it could not create the database, and exactly
    one such database must exist afterwards.
    """
    original_execute = sa.engine.Connection.execute

    def mock_execute(self, *args, **kwargs):
        # this simulates a race condition
        if kwargs == {"database_name": POSTGRES_TRACKER_STORE_DB}:
            original_execute(
                self.execution_options(isolation_level="AUTOCOMMIT"),
                f"CREATE DATABASE {POSTGRES_TRACKER_STORE_DB}",
            )
            return Mock(rowcount=0)
        return original_execute(self, *args, **kwargs)

    with monkeypatch.context() as mp:
        mp.setattr(sa.engine.Connection, "execute", mock_execute)
        with caplog.at_level(logging.ERROR):
            tracker_store = SQLTrackerStore(
                dialect="postgresql",
                host=POSTGRES_HOST,
                port=POSTGRES_PORT,
                username=POSTGRES_USER,
                password=POSTGRES_PASSWORD,
                db=POSTGRES_TRACKER_STORE_DB,
                login_db=POSTGRES_LOGIN_DB,
            )

    try:
        # IntegrityError has been caught and we log the error
        assert any(
            f"Could not create database '{POSTGRES_TRACKER_STORE_DB}'" in record.message
            for record in caplog.records
        )
        matching_rows = (
            postgres_login_db_connection.execution_options(isolation_level="AUTOCOMMIT")
            .execute(
                sa.text(
                    "SELECT 1 FROM pg_catalog.pg_database WHERE datname = :database_name"
                ),
                database_name=POSTGRES_TRACKER_STORE_DB,
            )
            .rowcount
        )
        assert matching_rows == 1
    finally:
        # Release the engine's pooled connections even if an assertion fails.
        tracker_store.engine.dispose()
コード例 #20
0
def test_ensure_schema_exists(monkeypatch: MonkeyPatch,
                              caplog: LogCaptureFixture) -> None:
    """An already existing schema is detected and reported in the logs.

    ``get_schema_names`` is mocked to return the configured schema and must be
    queried exactly once.
    """
    schema_name = "existing_schema"
    get_schema_names = Mock(return_value=[schema_name])
    monkeypatch.setattr(MSDialect_pyodbc, "get_schema_names", get_schema_names)

    database = Database(DatabaseType.internal)
    database._database_schema = schema_name

    with caplog.at_level(logging.INFO):
        ensure_schema_exists(database)
        assert f"Found schema: {schema_name} in internal database" in caplog.messages

    get_schema_names.assert_called_once()
コード例 #21
0
def test_graph_trainer_train_logging(
    tmp_path: Path,
    temp_cache: TrainingCache,
    train_with_schema: Callable,
    caplog: LogCaptureFixture,
):
    """Only ``SubtractByX`` produces start/finish lines on the hooks logger.

    The two ``ProvideX`` nodes must not add any messages to the
    ``rasa.engine.training.hooks`` logger output.
    """
    (tmp_path / "input_file.txt").write_text("3")

    def provide_node(**flags) -> SchemaNode:
        # Both ProvideX nodes are identical apart from the target/input flags.
        return SchemaNode(
            needs={},
            uses=ProvideX,
            fn="provide",
            constructor_name="create",
            config={},
            **flags,
        )

    schema = GraphSchema(
        {
            "input": provide_node(),
            "subtract 2": provide_node(is_target=True, is_input=True),
            "subtract": SchemaNode(
                needs={"i": "input"},
                uses=SubtractByX,
                fn="subtract_x",
                constructor_name="create",
                config={"x": 1},
                is_target=True,
                is_input=False,
            ),
        }
    )

    with caplog.at_level(logging.INFO, logger="rasa.engine.training.hooks"):
        train_with_schema(schema, temp_cache)

    assert caplog.messages == [
        "Starting to train component 'SubtractByX'.",
        "Finished training component 'SubtractByX'.",
    ]
コード例 #22
0
def test_sql_tracker_store_logs_do_not_show_password(caplog: LogCaptureFixture):
    """The logged database URL masks the password as ``***``.

    Username and password must be distinct values that do not overlap. If
    they were equal, the password would appear in the logged URL as the
    username, and ``password not in caplog.text`` could never hold.
    """
    dialect = "postgresql"
    host = "localhost"
    port = 9901
    db = "some-database"
    username = "db-user"
    password = "some-password"

    with caplog.at_level(logging.DEBUG):
        _ = SQLTrackerStore(None, dialect, host, port, db, username, password)

    # the URL in the logs does not contain the password
    assert password not in caplog.text

    # instead the password is displayed as '***'
    assert f"postgresql://{username}:***@{host}:{port}/{db}" in caplog.text
コード例 #23
0
    def test_no_ensure_tables_when_disabled(self, database_type: DatabaseType,
                                            monkeypatch: MonkeyPatch,
                                            caplog: LogCaptureFixture) -> None:
        """Table setup is skipped when database connections are disabled.

        ``create_all`` on the schema metadata is replaced by a mock that must
        never be called, and a log message explains why setup was skipped.
        """
        monkeypatch.setattr(master_config, "db_connection_attempts", 0)
        database = Database(database_type)

        create_all = Mock()
        metadata = database.schema_base_class.metadata
        monkeypatch.setattr(metadata, "create_all", create_all)

        # Drop records logged while the database object was constructed.
        caplog.clear()
        with caplog.at_level(logging.INFO):
            ensure_tables_exist(database)
            assert (
                f"Cannot setup tables, because {database} connection is not available"
                in caplog.messages
            )

        create_all.assert_not_called()
コード例 #24
0
def test_log_deprecation_warning_with_old_config(caplog: LogCaptureFixture):
    """Featurizing after the legacy ``HFTransformersNLP`` logs a deprecation note."""
    msg = Message.build("hi there")

    # Run the deprecated component on the message first.
    legacy_nlp = HFTransformersNLP(
        {"model_name": "bert", "model_weights": "bert-base-uncased"}
    )
    legacy_nlp.process(msg)

    caplog.set_level(logging.DEBUG)
    tokenizer = LanguageModelTokenizer()
    tokenizer.process(msg)
    featurizer = LanguageModelFeaturizer(skip_model_load=True)

    # Inspect only the records produced while the featurizer runs.
    caplog.clear()
    with caplog.at_level(logging.DEBUG):
        featurizer.process(msg)

    assert "deprecated component HFTransformersNLP" in caplog.text
コード例 #25
0
ファイル: test_bilou_utils.py プロジェクト: zylhub/rasa
def test_check_consistent_bilou_tagging(
    tags: List[Text],
    expected_tags: List[Text],
    debug_message: Optional[Text],
    caplog: LogCaptureFixture,
):
    """Check the tags returned by ``ensure_consistent_bilou_tagging``.

    The result must equal ``expected_tags``. Log records are expected only
    when ``debug_message`` is given, and they must then contain it.
    """
    with caplog.at_level(logging.DEBUG):
        result = bilou_utils.ensure_consistent_bilou_tagging(tags)

    if not debug_message:
        assert len(caplog.records) == 0
    else:
        assert len(caplog.records) > 0
        assert debug_message in caplog.text

    assert result == expected_tags
コード例 #26
0
def test_loading_from_storage_fail(
    training_data: TrainingData,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    caplog: LogCaptureFixture,
):
    """Loading from a missing resource still yields a classifier and warns."""
    component_cls = SklearnIntentClassifierGraphComponent

    with caplog.at_level(logging.WARNING):
        loaded = component_cls.load(
            component_cls.get_default_config(),
            default_model_storage,
            Resource("test"),
            default_execution_context,
        )
        assert isinstance(loaded, component_cls)

    warnings = [
        m for m in caplog.messages if "Resource 'test' doesn't exist." in m
    ]
    assert warnings
コード例 #27
0
ファイル: target_types_test.py プロジェクト: gatesn/pants
def test_entry_point_validation(caplog: LogCaptureFixture) -> None:
    """Malformed entry points are rejected; a trailing ``:`` only warns."""
    addr = Address("src/python/project")

    for invalid in (" ", "modue:func:who_knows_what_this_is", ":func"):
        with pytest.raises(InvalidFieldException):
            PexEntryPointField(invalid, address=addr)

    ep = "custom.entry_point:"
    with caplog.at_level(logging.WARNING):
        assert "custom.entry_point" == PexEntryPointField(ep, address=addr).value

    # Exactly one WARNING that names both the raw entry point and the address.
    assert len(caplog.record_tuples) == 1
    [(_, level, text)] = caplog.record_tuples
    assert level == logging.WARNING
    assert ep in text
    assert str(addr) in text
コード例 #28
0
async def test_nlg_conditional_response_variations_condition_logging(
    caplog: LogCaptureFixture,
):
    """Choosing a conditional response variation logs every matched condition."""
    domain = Domain.from_yaml(
        """
        version: "3.0"
        responses:
           utter_action:
             - text: "example"
               condition:
                - type: slot
                  name: test_A
                  value: A
                - type: slot
                  name: test_B
                  value: B
             - text: "default"
        """
    )
    nlg = TemplatedNaturalLanguageGenerator(domain.responses)

    # Both slots hold the values required by the "example" variation.
    slot_a = TextSlot(
        name="test_A", mappings=[{}], initial_value="A", influence_conversation=False
    )
    slot_b = TextSlot(
        name="test_B", mappings=[{}], initial_value="B", influence_conversation=False
    )
    tracker = DialogueStateTracker(sender_id="test", slots=[slot_a, slot_b])

    with caplog.at_level(logging.DEBUG):
        await nlg.generate("utter_action", tracker=tracker, output_channel="")

    expected_fragments = (
        "Selecting response variation with conditions:",
        "[condition 1] type: slot | name: test_A | value: A",
        "[condition 2] type: slot | name: test_B | value: B",
    )
    for fragment in expected_fragments:
        assert any(fragment in message for message in caplog.messages)
コード例 #29
0
def invoke_cli(caplog: LogCaptureFixture,
               cli: click.BaseCommand,
               args: Union[str, Iterable[str], None] = None,
               input: Optional[IO] = None,
               env: Optional[Mapping[str, str]] = None,
               catch_exceptions: bool = True,
               color: bool = False,
               mix_stderr: bool = False,
               **extra: Any) -> Result:
    """Invoke ``cli`` through a fresh ``CliRunner`` with the capture level raised.

    :param caplog: pytest log-capture fixture; its level is raised to 100000
        for the duration of the invocation.
    :param cli: the click command to invoke.
    :param args: forwarded to ``CliRunner.invoke``.
    :param input: forwarded to ``CliRunner.invoke``.
    :param env: forwarded to ``CliRunner.invoke``.
    :param catch_exceptions: forwarded to ``CliRunner.invoke``.
    :param color: forwarded to ``CliRunner.invoke``.
    :param mix_stderr: forwarded to ``CliRunner.invoke``.
    :param extra: any further keyword arguments for ``CliRunner.invoke``.
    :return: the result returned by ``CliRunner.invoke``.
    """
    runner = CliRunner()
    # A level far above CRITICAL effectively silences logging while the
    # command runs; workaround for click/issues/824.
    with caplog.at_level(100000):
        result = runner.invoke(
            cli,
            args=args,
            input=input,
            env=env,
            catch_exceptions=catch_exceptions,
            color=color,
            mix_stderr=mix_stderr,
            **extra,
        )
    return result
コード例 #30
0
async def test_reminder_lock(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    caplog: LogCaptureFixture,
):
    """Check that handling a reminder logs the deletion of the conversation lock.

    A tracker for a new random sender id is saved with a user message, the
    ``action_schedule_reminder`` action and a ``ReminderScheduled`` event.
    The processor then handles that reminder while DEBUG-level records are
    captured.
    """
    # Drop any records captured before this point.
    caplog.clear()
    with caplog.at_level(logging.DEBUG):
        # Random hex id so this conversation is separate from others in the store.
        sender_id = uuid.uuid4().hex

        reminder = ReminderScheduled("remind", datetime.datetime.now())
        tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

        # Replay a user message, the scheduling action and the reminder event.
        tracker.update(UserUttered("test"))
        tracker.update(ActionExecuted("action_schedule_reminder"))
        tracker.update(reminder)

        default_processor.tracker_store.save(tracker)

        await default_processor.handle_reminder(reminder, sender_id, default_channel)

        assert f"Deleted lock for conversation '{sender_id}'." in caplog.text