Example #1
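Trains dialogy's XLMRMultiClass intent classifier inside a Workflow and checks inference on the payload supplied by the test harness. The XLM-R constants are patched so that a MockClassifier defined in tests.plugin.text.classification.test_xlmr stands in for the real transformer model. The listing omits its imports; the block below is a sketch assuming dialogy's usual package layout, so adjust the paths if your install differs.

# Imports assumed, not part of the original listing; MockClassifier and payload
# are provided by the surrounding test module / test harness.
import json
import os

import pandas as pd

import dialogy.constants as const
from dialogy.base import Input
from dialogy.plugins import MergeASROutputPlugin, XLMRMultiClass
from dialogy.workflow import Workflow
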
def test_inference(payload):
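    # Save the real XLM-R constants and point them at the MockClassifier defined in this test module.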
    save_module_name = const.XLMR_MODULE
    save_model_name = const.XLMR_MULTI_CLASS_MODEL
    const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr"
    const.XLMR_MULTI_CLASS_MODEL = "MockClassifier"
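    # Start clean: remove any label encoder left behind by a previous run.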
    directory = "/tmp"
    file_path = os.path.join(directory, const.LABELENCODER_FILE)
    if os.path.exists(file_path):
        os.remove(file_path)

    transcripts = payload.get("input")
    intent = payload["expected"]["label"]

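    # Build the classifier and ASR-merge plugins, then chain them into a workflow.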
    xlmr_clf = XLMRMultiClass(
        model_dir=directory,
        dest="output.intents",
        debug=False,
    )

    merge_asr_output_plugin = MergeASROutputPlugin(
        dest="input.clf_feature",
        debug=False,
    )

    workflow = Workflow([merge_asr_output_plugin, xlmr_clf])

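    # Tiny training set: each row is a JSON-encoded utterance (a list of ASR alternatives) with its intent label.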
    train_df = pd.DataFrame([
        {"data": json.dumps([[{"transcript": "yes"}]]), "labels": "_confirm_"},
        {"data": json.dumps([[{"transcript": "yea"}]]), "labels": "_confirm_"},
        {"data": json.dumps([[{"transcript": "no"}]]), "labels": "_cancel_"},
        {"data": json.dumps([[{"transcript": "nope"}]]), "labels": "_cancel_"},
    ])

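    # Training should load the patched (mock) classifier.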
    workflow.train(train_df)
    assert isinstance(
        xlmr_clf.model, MockClassifier
    ), "model should be a MockClassifier after training."

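    # Run inference on the payload transcripts; the expected intent should win with a high score.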
    _, output = workflow.run(input_=Input(
        utterances=[[{"transcript": transcript} for transcript in transcripts]]
    ))
    assert output[const.INTENTS][0]["name"] == intent
    assert output[const.INTENTS][0]["score"] > 0.9

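    # Clean up the serialized label encoder and restore the patched constants.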
    if os.path.exists(file_path):
        os.remove(file_path)
    const.XLMR_MODULE = save_module_name
    const.XLMR_MULTI_CLASS_MODEL = save_model_name
Example #2
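Trains dialogy's MLPMultiClass intent classifier on merged ASR transcripts inside a Workflow, with grid search configured but disabled, and checks inference on the payload supplied by the test harness. As in Example #1, the imports below are assumed from dialogy's usual package layout rather than copied from the source.

# Imports assumed, not part of the original listing; payload is provided by the
# surrounding test harness.
import json
import os

import pandas as pd

import dialogy.constants as const
from dialogy.base import Input
from dialogy.plugins import MergeASROutputPlugin, MLPMultiClass
from dialogy.workflow import Workflow
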
def test_inference(payload):
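    # Start clean: remove any serialized MLP model left behind by a previous run.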
    directory = "/tmp"
    file_path = os.path.join(directory, const.MLPMODEL_FILE)
    if os.path.exists(file_path):
        os.remove(file_path)

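    # Training arguments: 5 epochs, with grid search configured but disabled (USE is False).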
    USE = "use"
    fake_args = {
        const.TRAIN: {
            const.NUM_TRAIN_EPOCHS: 5,
            const.USE_GRIDSEARCH: {
                USE: False,
                const.CV: 2,
                const.VERBOSE_LEVEL: 2,
                const.PARAMS: {
                    "activation": ["relu", "tanh"],
                    "hidden_layer_sizes": [(10, ), (2, 2)],
                    "ngram_range": [(1, 1), (1, 2)],
                    "max_iter": [20, 2],
                },
            },
        },
        const.TEST: {},
        const.PRODUCTION: {},
    }

    transcripts = payload.get("input")
    intent = payload["expected"]["label"]

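    # Build the MLP classifier and ASR-merge plugins, then chain them into a workflow.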
    mlp_clf = MLPMultiClass(
        model_dir=directory,
        dest="output.intents",
        args_map=fake_args,
        debug=False,
    )

    merge_asr_output_plugin = MergeASROutputPlugin(
        dest="input.clf_feature",
        debug=False,
    )

    workflow = Workflow([merge_asr_output_plugin, mlp_clf])

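    # Training data mixes single transcripts with merged alternatives wrapped in <s> ... </s> markers.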
    train_df = pd.DataFrame([
        {"data": json.dumps([[{"transcript": "yes"}]]), "labels": "_confirm_"},
        {"data": json.dumps([[{"transcript": "ye"}]]), "labels": "_confirm_"},
        {"data": json.dumps([[{"transcript": "<s> yes </s> <s> ye </s>"}]]), "labels": "_confirm_"},
        {"data": json.dumps([[{"transcript": "no"}]]), "labels": "_cancel_"},
        {"data": json.dumps([[{"transcript": "new"}]]), "labels": "_cancel_"},
        {"data": json.dumps([[{"transcript": "<s> new </s> <s> no </s>"}]]), "labels": "_cancel_"},
    ])

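    # Train, run inference on the payload transcripts, and check the top intent and its score.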
    workflow.train(train_df)
    _, output = workflow.run(Input(
        utterances=[[{"transcript": transcript} for transcript in transcripts]]
    ))
    assert output[const.INTENTS][0]["name"] == intent
    assert output[const.INTENTS][0]["score"] > 0.5
    if os.path.exists(file_path):
        os.remove(file_path)