コード例 #1
0
    def loadModels(self, force=False):
        """Load the intent and slot interpreters for each name in ``self.models``.

        Two ``RasaNLUConfig`` objects are built once from
        ``self.default_config_file`` and ``self.default_config_file_slots``
        and re-pointed at each model in turn through ``override``.

        For a model that ``isNluModelMissing`` does not report as missing:

        * when ``isNluModelModified`` is truthy, or ``force`` is set, the
          modification markers are refreshed and both interpreters are loaded
          from ``self.nlu_model_path``;
        * any training id still recorded for the model is unsubscribed from
          and cleared, whether or not a load took place.

        :param force: load the interpreters even when the model is not
            reported as modified.
        """
        intent_cfg = RasaNLUConfig(self.default_config_file)
        slot_cfg = RasaNLUConfig(self.default_config_file_slots)

        # The config files serve as templates; per-model values go on top.
        for name in self.models:
            intent_cfg.override({"project": "nlu", "fixed_model_name": name})
            slot_cfg.override({
                "project": "nlu",
                "fixed_model_name": "{}_slots".format(name),
            })

            if self.isNluModelMissing(name):
                # Nothing to load and no training bookkeeping to settle.
                continue

            if self.isNluModelModified(name) or force:
                print("loading nlu model {}".format(name))
                self.nlu_model_modified[name] = self.getNluModelModified(name)
                self.nlu_modified[name] = self.getNluModified(name)

                intent_dir = "{}/{}/".format(self.nlu_model_path, name)
                self.interpreter[name] = Interpreter.load(
                    intent_dir, intent_cfg)

                slots_dir = "{}/{}_slots/".format(self.nlu_model_path, name)
                self.interpreter_slots[name] = Interpreter.load(
                    slots_dir, slot_cfg)

                print('loaded nlu model {}'.format(name))

            # A model may have no training id at all (e.g. it was already on
            # disk), in which case there is nothing to unsubscribe from.
            if name not in self.trainingIds:
                continue
            pending_id = self.trainingIds[name]
            if pending_id is None:
                continue

            self.training_client.unsubscribe(
                "hermes/training/complete/{}".format(pending_id))
            self.training_request_models[pending_id] = None
            self.trainingIds[name] = None
コード例 #2
0
ファイル: training_server.py プロジェクト: mramshaw/opensnips
    def do_training(self, run_event):
        """Process queued training requests until ``run_event`` is cleared.

        Each pass pops at most one message from ``self.queue``, decodes its
        JSON payload and dispatches on the payload's ``type``; the loop then
        sleeps for five seconds before checking ``run_event`` again.

        Requests without an ``id`` are reported and dropped. Only the
        ``rasanlu`` and ``rasacore`` types are trained; every other type is
        ignored.

        :param run_event: object exposing ``is_set()``; the loop runs while
            it returns a truthy value.
        """
        while run_event.is_set():
            if self.queue:
                msg = self.queue.pop(0)
                payload = json.loads(msg.payload)
                theId = payload.get('id')
                print('HANDLE TRAINING REQUEST {}'.format(theId))
                if theId is None:
                    print("Required ID missing in training request")
                else:
                    trainingType = payload.get('type')
                    if trainingType == "rasanlu":
                        self._train_rasa_nlu(theId, payload)
                    elif trainingType == "rasacore":
                        self._train_rasa_core(theId, payload)
                    # "kaldi", "piwho" and "snowboy" are known request types
                    # with no trainer yet; like any other type they are
                    # accepted and ignored.
            time.sleep(5)

    def _train_nlu_model(self, rasaConfig, training_data, tmpdir, project,
                         model_name):
        """Train one NLU model from markdown examples and persist it."""
        trainer = Trainer(rasaConfig)
        data = SnipsMarkdownToJson(training_data)
        trainer.train(
            TrainingData(
                data.common_examples,
                get_entity_synonyms_dict(data.entity_synonyms)))
        trainer.persist(
            tmpdir, project_name=project, fixed_model_name=model_name)

    def _train_rasa_nlu(self, theId, payload):
        """Train the intent model and the slots model for a "rasanlu" request."""
        # Both config entries must be JSON strings: json.loads raises
        # TypeError when either is absent from the payload.
        trainingConfig = json.loads(payload.get('config'))
        slotConfig = json.loads(payload.get('config_slots'))
        project = payload.get('project', 'nlu')
        model = payload.get('model', 'default')
        training_data = payload.get('examples', '')

        tmpdir = tempfile.mkdtemp()

        # train intents and slots
        rasaConfig = RasaNLUConfig()
        rasaConfig.override(trainingConfig)
        self._train_nlu_model(rasaConfig, training_data, tmpdir, project,
                              model)

        # train slots only for partial query; the slot settings are layered
        # on top of the already-overridden intent config, as before.
        rasaConfig.override(slotConfig)
        self._train_nlu_model(rasaConfig, training_data, tmpdir, project,
                              "{}_slots".format(model))

        # send mqtt training/complete
        # NOTE(review): tmpdir is not removed here because its path is handed
        # to send_training_complete -- confirm who is responsible for cleanup.
        self.send_training_complete(theId, "{}/{}".format(tmpdir, project))

    def _train_rasa_core(self, theId, payload):
        """Train and persist a dialogue model for a "rasacore" request."""
        print("training rasa core")
        training_data = payload.get('stories')
        domain_data = payload.get('domain')

        trainingFile = tempfile.NamedTemporaryFile(
            delete=False, suffix='.md')
        tmpdir = tempfile.mkdtemp()
        domainPath = os.path.join(tmpdir, "domain.yml")
        try:
            with trainingFile:
                # NamedTemporaryFile opens in binary mode by default, so the
                # decoded JSON text has to be encoded before it is written.
                trainingFile.write(training_data.encode('utf-8'))
            # Explicit encoding so both input files are written as UTF-8
            # regardless of the process locale.
            with io.open(domainPath, "w", encoding="utf-8") as domainFile:
                domainFile.write(domain_data)

            agent = SnipsMqttAgent.createAgent(tmpdir)
            agent.train(
                trainingFile.name,
                max_history=3,
                epochs=30,  # was 100
                batch_size=50,
                augmentation_factor=50,
                validation_split=0.2)
            print("traing rasa core done train")
        finally:
            # cleanup: remove the input files even when training fails.
            # The (now empty) tmpdir itself is left in place, as before.
            for path in (domainPath, trainingFile.name):
                if os.path.exists(path):
                    os.unlink(path)

        # persist
        modelPath = tempfile.mkdtemp()
        agent.persist(modelPath)
        self.send_training_complete(theId, modelPath)
        print("traing rasa core sent")