예제 #1
0
파일: rasa.py 프로젝트: nan0tube/rasa_nlu
    def read_from_json(self, js, **kwargs):
        """Load training data stored in the rasa NLU JSON data format.

        Args:
            js: Parsed JSON document. It is passed to
                ``validate_rasa_nlu_data`` first and must contain a
                ``rasa_nlu_data`` key.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``TrainingData`` built from every example found in the
            ``common_examples``, ``intent_examples`` and ``entity_examples``
            sections, together with the transformed entity synonyms and the
            regex features.
        """
        validate_rasa_nlu_data(js)

        data = js['rasa_nlu_data']
        common_examples = data.get("common_examples", [])
        intent_examples = data.get("intent_examples", [])
        entity_examples = data.get("entity_examples", [])
        entity_synonyms = data.get("entity_synonyms", [])
        regex_features = data.get("regex_features", [])

        entity_synonyms = transform_entity_synonyms(entity_synonyms)

        if intent_examples or entity_examples:
            # ``Logger.warn`` is a deprecated alias; ``warning`` is the
            # supported spelling.
            logger.warning("DEPRECATION warning: your rasa data "
                           "contains 'intent_examples' "
                           "or 'entity_examples' which will be "
                           "removed in the future. Consider "
                           "putting all your examples "
                           "into the 'common_examples' section.")

        # The legacy sections are still honoured: they are merged behind the
        # common examples, in that order.
        all_examples = common_examples + intent_examples + entity_examples
        training_examples = [
            Message.build(ex['text'], ex.get("intent"), ex.get("entities"))
            for ex in all_examples
        ]

        return TrainingData(training_examples, entity_synonyms, regex_features)
예제 #2
0
파일: rasa.py 프로젝트: hee2000/rasa_nlu
    def read_from_json(self, js, **kwargs):
        """Build a ``TrainingData`` object from a rasa NLU JSON document.

        The document is handed to ``validate_rasa_nlu_data`` first; its
        ``rasa_nlu_data`` section is then read. Examples from the legacy
        ``intent_examples`` / ``entity_examples`` sections are still accepted
        (with a deprecation warning) and are appended after the common ones.
        ``**kwargs`` is accepted but not used here.
        """
        validate_rasa_nlu_data(js)

        payload = js['rasa_nlu_data']
        common = payload.get("common_examples", [])
        legacy_intent = payload.get("intent_examples", [])
        legacy_entity = payload.get("entity_examples", [])
        regex_features = payload.get("regex_features", [])
        lookup_tables = payload.get("lookup_tables", [])
        synonyms = transform_entity_synonyms(
            payload.get("entity_synonyms", []))

        uses_legacy_sections = bool(legacy_intent or legacy_entity)
        if uses_legacy_sections:
            logger.warning("DEPRECATION warning: your rasa data "
                           "contains 'intent_examples' "
                           "or 'entity_examples' which will be "
                           "removed in the future. Consider "
                           "putting all your examples "
                           "into the 'common_examples' section.")

        messages = [
            Message.build(example['text'],
                          example.get("intent"),
                          example.get("entities"))
            for example in common + legacy_intent + legacy_entity
        ]

        return TrainingData(messages, synonyms, regex_features,
                            lookup_tables)
예제 #3
0
파일: dataset.py 프로젝트: rikhuijzer/bench
def create_message(text: str, intent: str, entities: list, training: bool,
                   corpus: tp.Corpus) -> Message:
    """Create a Rasa ``Message`` tagged with its split and source corpus.

    Args:
        text: Sentence text passed to ``Message.build``.
        intent: Intent label passed to ``Message.build``.
        entities: Entity annotations passed to ``Message.build``.
        training: Stored under ``message.data['training']``; marks whether
            this is a train or a test sentence.
        corpus: Stored under ``message.data['corpus']``.

    Returns:
        The built message with the two extra ``data`` entries set.
    """
    message = Message.build(text, intent, entities)
    message.data['training'] = training
    message.data['corpus'] = corpus
    return message
예제 #4
0
    def _read_intent(self, intent_js, examples_js):
        """Turn one intent definition and its examples into training data.

        The intent name is read from ``intent_js["name"]``. Each entry of
        ``examples_js`` contributes one message, whose text and entities come
        from ``self._join_text_chunks`` applied to the entry's ``data`` field.
        """
        intent_name = intent_js.get("name")

        # Lazily joined (text, entities) pairs, consumed one at a time below.
        joined = (self._join_text_chunks(example['data'])
                  for example in examples_js)
        messages = [Message.build(text, intent_name, entities)
                    for text, entities in joined]

        return TrainingData(messages)
예제 #5
0
    def _read_intent(self, intent_js, examples_js):
        """Create ``TrainingData`` for a single intent.

        ``intent_js`` supplies the intent name via its ``"name"`` key.
        ``examples_js`` is iterated, and each item's ``data`` field is run
        through ``self._join_text_chunks`` to obtain the message text and
        entities.
        """
        name = intent_js.get("name")

        def to_message(example):
            # One example entry -> one message carrying the shared intent.
            text, entities = self._join_text_chunks(example['data'])
            return Message.build(text, name, entities)

        return TrainingData([to_message(example) for example in examples_js])
예제 #6
0
def convert_line_message(line: List[str]) -> Message:
    """Convert a ``[sentence, intent(, training)]`` row into a ``Message``.

    The sentence is passed through ``convert_annotated_text`` and the message
    is built with an empty entity list. When the row has exactly three
    fields, the third is compared with the string ``'True'`` and the result
    is stored under ``message.data['training']``.

    Raises:
        AssertionError: If the row has fewer than two fields.
    """
    field_count = len(line)
    if field_count < 2:
        template = ('Line should be a list containing at least '
                    '[sentence, intent], got {}')
        raise AssertionError(template.format(line))

    message = Message.build(
        text=convert_annotated_text(line[0]),
        intent=line[1],
        entities=[],
    )

    has_training_flag = field_count == 3
    if has_training_flag:
        message.data['training'] = line[2] == 'True'

    return message
예제 #7
0
def train_update(update, by):
    """Train a Rasa NLU model for ``update`` and persist it.

    Steps, in order:
      1. ``update.start_training(by)`` is called.
      2. One intent message is built per example in ``update.examples``.
      3. One label message is built per example that has at least one
         entity with a non-null label; only entities whose label is truthy
         are included, converted with ``label_as_entity=True``.
      4. A ``Trainer`` is trained on both sets and persisted through a
         ``BothubPersistor`` into a fresh temporary directory.

    NOTE(review): the directory returned by ``mkdtemp()`` is not removed
    here; confirm whether cleanup happens elsewhere.
    """
    update.start_training(by)

    intent_messages = []
    for example in update.examples:
        text = example.get_text(update.language)
        intent = example.intent
        entity_payloads = [
            entity.rasa_nlu_data
            for entity in example.get_entities(update.language)
        ]
        intent_messages.append(
            Message.build(text=text, intent=intent, entities=entity_payloads))

    # Examples owning at least one entity whose label is set.
    labelled_examples = (
        update.examples
        .filter(entities__entity__label__isnull=False)
        .annotate(entities_count=models.Count('entities'))
        .filter(entities_count__gt=0)
    )

    label_messages = []
    for example in labelled_examples:
        text = example.get_text(update.language)
        label_payloads = [
            entity.get_rasa_nlu_data(label_as_entity=True)
            for entity in example.get_entities(update.language)
            if entity.entity.label
        ]
        label_messages.append(
            Message.build(text=text, entities=label_payloads))

    config = get_rasa_nlu_config_from_update(update)
    builder = ComponentBuilder(use_cache=False)
    trainer = Trainer(config, builder)
    trainer.train(
        BothubTrainingData(label_training_examples=label_messages,
                           training_examples=intent_messages))

    persistor = BothubPersistor(update)
    model_dir = mkdtemp()
    trainer.persist(model_dir,
                    persistor=persistor,
                    project_name=str(update.repository.uuid),
                    fixed_model_name=str(update.id))