Beispiel #1
0
async def test_sparse_feature_sizes_decreased_incremental_training(
    iter1_path: Text,
    iter2_path: Text,
    should_raise_exception: bool,
    create_response_selector: Callable[
        [Dict[Text, Any]], ResponseSelectorGraphComponent
    ],
    load_response_selector: Callable[[Dict[Text, Any]], ResponseSelectorGraphComponent],
    default_execution_context: ExecutionContext,
    train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
    process_message: Callable[..., Message],
):
    """Check incremental training of a response selector across two datasets.

    A selector is trained (1 epoch) on data featurized from ``iter1_path``.
    It is then reloaded after ``default_execution_context.is_finetuning`` is set
    to True, and the reloaded selector must give a processed message with the
    same fingerprint as the original.

    Finally, the reloaded selector is trained on data featurized from
    ``iter2_path``:

    - If ``should_raise_exception`` is set, this training must raise an
      exception whose message contains "Sparse feature sizes have decreased".
    - Otherwise, training must succeed and leave a truthy ``model``.
    """
    pipeline = [
        {"component": WhitespaceTokenizerGraphComponent},
        {"component": LexicalSyntacticFeaturizerGraphComponent},
        {"component": RegexFeaturizerGraphComponent},
        {"component": CountVectorsFeaturizerGraphComponent},
        {
            "component": CountVectorsFeaturizerGraphComponent,
            "analyzer": "char_wb",
            "min_ngram": 1,
            "max_ngram": 4,
        },
    ]
    training_data, loaded_pipeline = train_and_preprocess(pipeline, iter1_path)

    response_selector = create_response_selector({EPOCHS: 1})
    response_selector.train(training_data=training_data)

    message = Message(data={TEXT: "Rasa is great!"})
    message = process_message(loaded_pipeline, message)

    # Independent copy so the original and the reloaded selector never
    # process the same Message object.
    message2 = copy.deepcopy(message)

    classified_message = response_selector.process([message])[0]

    # NOTE(review): presumably load_response_selector reads this flag through
    # the shared execution-context fixture, putting the selector into
    # finetuning mode — confirm against the fixture definition.
    default_execution_context.is_finetuning = True
    loaded_selector = load_response_selector({EPOCHS: 1})

    classified_message2 = loaded_selector.process([message2])[0]

    # Processing with the reloaded selector must yield an identical message.
    assert classified_message2.fingerprint() == classified_message.fingerprint()

    # Featurize the second iteration's data *outside* ``pytest.raises`` so only
    # an error raised by the selector's training can satisfy the expectation.
    training_data2, _ = train_and_preprocess(pipeline, iter2_path)

    if should_raise_exception:
        with pytest.raises(Exception, match="Sparse feature sizes have decreased"):
            loaded_selector.train(training_data=training_data2)
    else:
        loaded_selector.train(training_data=training_data2)
        assert loaded_selector.model
Beispiel #2
0
    def test_epoch_override_when_loaded(
        self,
        trained_policy: TEDPolicy,
        should_finetune: bool,
        epoch_override: int,
        expected_epoch_value: int,
        resource: Resource,
        model_storage: ModelStorage,
        execution_context: ExecutionContext,
    ):
        """Check the ``EPOCHS`` value a policy ends up with after loading.

        The execution context's finetuning flag is set to ``should_finetune``.
        The policy's class then reloads it from ``model_storage``/``resource``
        with ``EPOCH_OVERRIDE`` merged into the test config. The loaded
        policy's ``EPOCHS`` must equal ``expected_epoch_value``.
        """
        # Set the finetuning flag before anything else, as the load depends on
        # the context it is given.
        execution_context.is_finetuning = should_finetune

        config_with_override = {**self._config(), EPOCH_OVERRIDE: epoch_override}
        policy_cls = trained_policy.__class__
        reloaded = policy_cls.load(
            config_with_override, model_storage, resource, execution_context
        )

        assert reloaded.config[EPOCHS] == expected_epoch_value