def test_inference(payload):
    """Train a mocked XLMR classifier end-to-end and verify intent inference.

    Monkey-patches ``const`` so :class:`XLMRMultiClass` loads the test
    ``MockClassifier`` instead of the real (heavy) XLMR model, trains on a
    tiny confirm/cancel dataset, then asserts the workflow predicts the
    expected intent for ``payload["input"]``.

    :param payload: test case with keys ``input`` (list of transcripts) and
        ``expected.label`` (the intent name the classifier should produce).

    Fix over the original: teardown (restoring the patched ``const`` values
    and removing the label-encoder file) now runs in a ``finally`` block, so
    a failing assertion no longer leaks patched global state into later
    tests. NOTE(review): pytest's ``monkeypatch``/``tmp_path`` fixtures would
    be the idiomatic way to do this — left as-is to keep the signature.
    """
    # Swap in the mock module/model; restored unconditionally in `finally`.
    save_module_name = const.XLMR_MODULE
    save_model_name = const.XLMR_MULTI_CLASS_MODEL
    const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr"
    const.XLMR_MULTI_CLASS_MODEL = "MockClassifier"

    directory = "/tmp"
    file_path = os.path.join(directory, const.LABELENCODER_FILE)
    # Remove any stale encoder left behind by a previous run.
    if os.path.exists(file_path):
        os.remove(file_path)

    try:
        transcripts = payload.get("input")
        intent = payload["expected"]["label"]

        xlmr_clf = XLMRMultiClass(
            model_dir=directory,
            dest="output.intents",
            debug=False,
        )
        merge_asr_output_plugin = MergeASROutputPlugin(
            dest="input.clf_feature", debug=False
        )
        workflow = Workflow([merge_asr_output_plugin, xlmr_clf])

        # Minimal 2-class training set; "data" mimics serialized ASR output.
        train_df = pd.DataFrame([
            {"data": json.dumps([[{"transcript": "yes"}]]), "labels": "_confirm_"},
            {"data": json.dumps([[{"transcript": "yea"}]]), "labels": "_confirm_"},
            {"data": json.dumps([[{"transcript": "no"}]]), "labels": "_cancel_"},
            {"data": json.dumps([[{"transcript": "nope"}]]), "labels": "_cancel_"},
        ])
        workflow.train(train_df)

        # Confirms the const patch actually took effect during training.
        assert isinstance(
            xlmr_clf.model, MockClassifier
        ), "model should be a MockClassifier after training."

        _, output = workflow.run(
            input_=Input(
                utterances=[[{"transcript": t} for t in transcripts]]
            )
        )
        assert output[const.INTENTS][0]["name"] == intent
        assert output[const.INTENTS][0]["score"] > 0.9
    finally:
        # Always clean up, even when an assertion above fails: remove the
        # encoder file and restore the patched const values so other tests
        # see the real XLMR module/model.
        if os.path.exists(file_path):
            os.remove(file_path)
        const.XLMR_MODULE = save_module_name
        const.XLMR_MULTI_CLASS_MODEL = save_model_name
def test_inference(payload):
    """Train an MLP intent classifier on a tiny dataset and verify inference.

    Builds an ASR-merge -> MLP workflow with grid search disabled, trains on
    a small confirm/cancel dataset (including ``<s> ... </s>``-wrapped
    multi-transcript rows), then asserts the workflow predicts the expected
    intent for ``payload["input"]``.

    :param payload: test case with keys ``input`` (list of transcripts) and
        ``expected.label`` (the intent name the classifier should produce).

    Fix over the original: the trained-model file removal now runs in a
    ``finally`` block, so a failing assertion no longer leaves a stale model
    in ``/tmp`` for later tests. NOTE(review): pytest's ``tmp_path`` fixture
    would avoid the shared ``/tmp`` path entirely — left as-is to keep the
    signature.
    """
    directory = "/tmp"
    file_path = os.path.join(directory, const.MLPMODEL_FILE)
    # Remove any stale model left behind by a previous run.
    if os.path.exists(file_path):
        os.remove(file_path)

    USE = "use"
    # Keep epochs low and grid search off so the test stays fast.
    fake_args = {
        const.TRAIN: {
            const.NUM_TRAIN_EPOCHS: 5,
            const.USE_GRIDSEARCH: {
                USE: False,
                const.CV: 2,
                const.VERBOSE_LEVEL: 2,
                const.PARAMS: {
                    "activation": ["relu", "tanh"],
                    "hidden_layer_sizes": [(10,), (2, 2)],
                    "ngram_range": [(1, 1), (1, 2)],
                    "max_iter": [20, 2],
                },
            },
        },
        const.TEST: {},
        const.PRODUCTION: {},
    }

    try:
        transcripts = payload.get("input")
        intent = payload["expected"]["label"]

        mlp_clf = MLPMultiClass(
            model_dir=directory,
            dest="output.intents",
            args_map=fake_args,
            debug=False,
        )
        merge_asr_output_plugin = MergeASROutputPlugin(
            dest="input.clf_feature",
            debug=False,
        )
        workflow = Workflow([merge_asr_output_plugin, mlp_clf])

        # Minimal 2-class training set; "data" mimics serialized ASR output,
        # including merged multi-transcript rows wrapped in <s>...</s>.
        train_df = pd.DataFrame([
            {"data": json.dumps([[{"transcript": "yes"}]]), "labels": "_confirm_"},
            {"data": json.dumps([[{"transcript": "ye"}]]), "labels": "_confirm_"},
            {
                "data": json.dumps([[{"transcript": "<s> yes </s> <s> ye </s>"}]]),
                "labels": "_confirm_",
            },
            {"data": json.dumps([[{"transcript": "no"}]]), "labels": "_cancel_"},
            {"data": json.dumps([[{"transcript": "new"}]]), "labels": "_cancel_"},
            {
                "data": json.dumps([[{"transcript": "<s> new </s> <s> no </s>"}]]),
                "labels": "_cancel_",
            },
        ])
        workflow.train(train_df)

        _, output = workflow.run(
            Input(utterances=[[{"transcript": t} for t in transcripts]])
        )
        assert output[const.INTENTS][0]["name"] == intent
        assert output[const.INTENTS][0]["score"] > 0.5
    finally:
        # Always remove the trained model, even when an assertion fails,
        # so subsequent tests start from a clean /tmp.
        if os.path.exists(file_path):
            os.remove(file_path)