コード例 #1
0
    def _recurse(
        self,
        registry: Registry,
        random: Generator,
        language: str,
        this: DocumentPlanNode,
        previous_entities: DefaultDict[str, None],
        encountered: Set[str],
    ) -> Tuple[Set[str], DefaultDict[str, None]]:
        """
        Walk the document plan in-order and pick a name form for each
        named-entity slot.

        For every ``Slot`` whose value ``self.is_entity`` accepts, the slot's
        ``name_type`` attribute is set to:

        * ``"pronoun"`` if the entity equals the most recent entity of the same
          type (``previous_entities[entity_type]``);
        * ``"short"`` if it differs from that but is already in ``encountered``;
        * ``"full"`` otherwise, and the entity is added to ``encountered``.

        ``self.resolve_surface_form`` is then called for the slot. After that,
        ``entity_type`` is stored on the slot's attributes and the entity
        becomes the latest one recorded for its type.

        Non-slot ``DocumentPlanNode`` children are visited recursively; any
        other node is ignored.

        Args:
            registry: Passed through to ``resolve_surface_form`` and recursive calls.
            random: Passed through to ``resolve_surface_form`` and recursive calls.
            language: Passed through to ``resolve_surface_form`` and recursive calls.
            this: The node currently being visited.
            previous_entities: Entity type -> last entity of that type seen;
                updated in place.
            encountered: Entities already mentioned; updated in place.

        Returns:
            The ``encountered`` and ``previous_entities`` objects (the same
            objects that were passed in, possibly mutated).
        """
        if not isinstance(this, Slot):
            if isinstance(this, DocumentPlanNode):
                log.debug("Visiting non-leaf '{}'".format(this))
                for node in this.children:
                    encountered, previous_entities = self._recurse(
                        registry,
                        random,
                        language,
                        node,
                        previous_entities,
                        encountered,
                    )
            return encountered, previous_entities

        # Leaf slot: bail out early unless it holds a named entity.
        if not self.is_entity(this.value):
            log.debug("Visited non-NE leaf node {}".format(this.value))
            return encountered, previous_entities

        log.debug("Visiting NE leaf {}".format(this.value))
        entity_type, entity = self.parse_entity(this.value)

        if previous_entities[entity_type] == entity:
            log.debug("Same as previous entity")
            name_type = "pronoun"
        elif entity in encountered:
            log.debug(
                "Different entity than previous, but has been previously encountered"
            )
            name_type = "short"
        else:
            log.debug("First time encountering this entity")
            name_type = "full"
            encountered.add(entity)
            log.debug(
                "Added entity to encountered, all encountered: {}".format(
                    encountered))
        # Must be set before resolving, since resolution depends on the chosen form.
        this.attributes["name_type"] = name_type

        self.resolve_surface_form(registry, random, language, this, entity,
                                  entity_type)
        log.debug("Resolved entity name")

        this.attributes["entity_type"] = entity_type
        previous_entities[entity_type] = entity
        return encountered, previous_entities
コード例 #2
0
ファイル: aggregator.py プロジェクト: ljleppan/eu-nlg-prod
    def run(self, registry: Registry, random: Generator, language: str,
            document_plan: DocumentPlanNode) -> Tuple[DocumentPlanNode]:
        """
        Aggregate ``document_plan`` in place via ``self._aggregate``.

        ``random`` is accepted for pipeline-interface compatibility but is not
        used. When DEBUG logging is enabled, the tree is printed both before
        and after aggregation.

        Returns:
            A one-element tuple holding the (mutated) ``document_plan``.
        """

        def dump_tree_if_debugging() -> None:
            # Re-checked on each call, so a level change in between is honoured.
            if log.isEnabledFor(logging.DEBUG):
                document_plan.print_tree()

        dump_tree_if_debugging()

        log.debug("Aggregating")
        self._aggregate(registry, language, document_plan)

        dump_tree_if_debugging()

        return document_plan,
コード例 #3
0
    def _recurse(
        self,
        registry: Registry,
        random: Generator,
        language: str,
        this: DocumentPlanNode,
    ):
        """
        Walk the document plan and realize ordinal-marked slots.

        A ``Slot`` whose ``attributes`` contain a truthy ``"ord"`` entry is
        passed to the ``"ord"`` realizer registered for ``language`` in
        ``self.realizers``. The slot's ``value`` is then replaced by a function
        that returns the realized result. If no such realizer is available, an
        error is logged and the slot is left unchanged. Non-slot
        ``DocumentPlanNode``s are descended into; anything else is ignored.
        """
        realizers_for_language = self.realizers.get(language, {})

        if isinstance(this, Slot):
            slot = cast(Slot, this)
            wants_ordinal = slot.attributes and slot.attributes.get("ord")
            if not wants_ordinal:
                return

            ordinal_realizer = realizers_for_language.get("ord")
            if not ordinal_realizer:
                log.error(
                    "Wanted to realize as ordinal '{}' but found no realizer."
                    .format(slot.value))
                return

            realized = ordinal_realizer.realize(slot)
            # ``realized`` is a fresh local per call, so the closure is safe.
            slot.value = lambda x: realized
            return

        if isinstance(this, DocumentPlanNode):
            log.debug("Visiting non-leaf '{}'".format(this))
            for node in this.children:
                self._recurse(registry, random, language, node)
コード例 #4
0
def realize_message(message: Message, template: Template,
                    language: str) -> Message:
    """
    Realize a single message with a given template and return it.

    ``template`` is attached to ``message``. A temporary ``DocumentPlanNode``
    containing only ``message`` is then passed, in order, through the slot,
    date, entity-name, number and morphological realizers.

    The module logger ``log`` is raised to WARNING while this runs, to
    suppress per-step output. Its previous level is restored afterwards,
    including when a realizer raises. The realizers receive a NumPy
    ``Generator`` seeded with 42, so any random choices drawn from it are
    reproducible across calls.

    Args:
        message: The message to realize.
        template: The template to attach to ``message``.
        language: Language code passed to every realizer.

    Returns:
        The single child of the temporary document plan after realization.
    """
    rnd = np.random.default_rng(42)

    # Quieten logging for the duration; restored in ``finally`` so that an
    # exception from any realizer cannot leave the logger stuck at WARNING.
    old_log_level = log.level
    log.setLevel(logging.WARNING)
    try:
        # NOTE: uses the non-public TemplateSelector._add_template_to_message
        # to attach ``template`` without running full template selection.
        TEMPLATE_SELECTOR._add_template_to_message(message, template, [message])

        doc_plan = DocumentPlanNode([message])
        for component in (SLOT_REALIZER, DATE_REALIZER, ENTITY_NAME_RESOLVER,
                          NUMBER_REALIZER, MORPHOLOGICAL_REALIZER):
            # Every component returns a one-element tuple holding the plan.
            (doc_plan, ) = component.run(SERVICE.registry, rnd, language,
                                         doc_plan)
    finally:
        log.setLevel(old_log_level)

    msg: Message = doc_plan.children[0]
    return msg
コード例 #5
0
    def run(self, registry: Registry, random: Generator, language: str,
            document_plan: DocumentPlanNode) -> Tuple[DocumentPlanNode]:
        """
        Run the date-realization traversal (``self._recurse``) over
        ``document_plan``, which is modified in place.

        A trailing ``"-head"`` on ``language`` is stripped before traversal.
        When DEBUG logging is enabled, the resulting tree is printed.

        Returns:
            A one-element tuple holding ``document_plan``.
        """
        log.info("Realizing dates")

        head_suffix = "-head"
        if language.endswith(head_suffix):
            language = language[:len(language) - len(head_suffix)]
            log.debug(
                "Language had suffix '-head', removing. Result: {}".format(
                    language))

        self._recurse(registry, random, language, document_plan)

        debug_enabled = log.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            document_plan.print_tree()

        return document_plan,
コード例 #6
0
    def run(self, registry: Registry, random: Generator, language: str,
            document_plan: DocumentPlanNode) -> Tuple[DocumentPlanNode]:
        """
        Run the entity-name traversal (``self._recurse``) over
        ``document_plan``, which is modified in place.

        A trailing ``"-head"`` on ``language`` is stripped first. The traversal
        starts with an empty "previous entity per type" mapping (missing keys
        default to ``None``) and an empty set of encountered entities. When
        DEBUG logging is enabled, the resulting tree is printed.

        Returns:
            A one-element tuple holding ``document_plan``.
        """
        log.info("Running NER")

        suffix = "-head"
        if language.endswith(suffix):
            language = language[:-len(suffix)]
            log.debug(
                "Language had suffix '-head', removing. Result: {}".format(
                    language))

        self._recurse(
            registry,
            random,
            language,
            document_plan,
            defaultdict(lambda: None),
            set(),
        )

        debug_enabled = log.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            document_plan.print_tree()

        return document_plan,
コード例 #7
0
    def run(
        self,
        registry: Registry,
        random: RandomState,
        language: str,
        core_messages: List[Message],
        expanded_messages: List[Message],
    ):
        """
        Run the first half of the realization pipeline on every message.

        Each message from ``core_messages`` followed by ``expanded_messages``
        is wrapped in its own single-message ``DocumentPlanNode``. That plan is
        run through template selection, slot realization, date realization,
        entity-name resolution and number realization. The per-message plans
        are discarded, so any lasting effect has to come from the components
        mutating the messages in place.

        The root logger is raised to WARNING for the duration, because this
        loop produces a lot of log output. Its original level is restored
        afterwards, also when a component raises.

        Returns:
            ``(core_messages, expanded_messages)``, the same list objects
            that were passed in.
        """
        template_selector = TemplateSelector()
        slot_realizer = SlotRealizer()
        # NOTE(review): "ee" is the ISO 639-1 code for Ewe; Estonian is "et".
        # Kept unchanged because callers may rely on "ee" -- confirm.
        date_realizer = LanguageSplitComponent(
            {
                "en": EnglishEUDateRealizer(),
                "fi": FinnishEUDateRealizer(),
                "hr": CroatianEUDateRealizer(),
                "de": GermanEUDateRealizer(),
                "ru": RussianEUDateRealizer(),
                "ee": EstonianEUDateRealizer(),
                "sl": SlovenianEUDateRealizer(),
            }
        )
        number_realizer = EUNumberRealizer()
        entity_name_resolver = EUEntityNameResolver()

        root_logger = logging.getLogger()
        original_log_level = root_logger.level
        log.info(
            f"Setting root log level to WARNING (={logging.WARNING}) temporarily (was {original_log_level}), "
            f"because we're going to produce hella spam by running the first half of the pipeline at least a few "
            f"thousand times."
        )
        root_logger.setLevel(logging.WARNING)
        try:
            for msg in itertools.chain(core_messages, expanded_messages):
                doc_plan = DocumentPlanNode([msg])
                template_selector.run(registry, random, language, doc_plan, core_messages)
                slot_realizer.run(registry, random, language, doc_plan)
                date_realizer.run(registry, random, language, doc_plan)
                entity_name_resolver.run(registry, random, language, doc_plan)
                number_realizer.run(registry, random, language, doc_plan)
        finally:
            # Always restore, so that one failing message cannot leave the
            # whole process's logging muted at WARNING.
            root_logger.setLevel(original_log_level)
            log.info(f"Log level restored to {original_log_level}")
        return core_messages, expanded_messages
コード例 #8
0
    def _recurse(
        self,
        registry: Registry,
        random: Generator,
        language: str,
        this: DocumentPlanNode,
        previous_entity: Optional[str],
    ) -> Optional[str]:
        """
        Traverse ``this`` in-order and realize ``[TIME:<type>...]`` slots.

        A child ``Slot`` whose value is a string ``"[TIME:<type>...]"`` with
        ``<type>`` equal to ``"month"`` or ``"year"`` is realized via
        ``self._realize_month`` / ``self._realize_year``. These receive the
        previously realized TIME tag as context. If the realizer returns a
        list, one option is picked with ``random.choice``.

        The chosen string is split on whitespace. The slot is then replaced
        in ``this.children`` by one copied slot per token, each with a
        ``value`` function returning that token. When ``self.attach_attributes``
        has an entry for the timestamp type, only the token positions listed
        there keep the original attributes; the other copies get ``{}``.

        TIME tags of any other type (or with no type at all) are logged as
        errors and left in place. Child ``DocumentPlanNode``s are descended
        into; other children (e.g. literals) are skipped.

        Returns:
            The original tag string of the most recently realized TIME slot,
            or the incoming ``previous_entity`` if none was realized. Callers
            thread this value through sibling subtrees.
        """
        idx = 0
        while idx < len(this.children):
            child = this.children[idx]
            if isinstance(child, Slot):
                value = child.value
                # startswith/endswith also handle "", where value[0] would raise.
                if not (isinstance(value, str) and value.startswith("[")
                        and value.endswith("]")):
                    log.debug("Visited non-tag leaf node {}".format(value))
                    idx += 1
                    continue

                segments = value[1:-1].split(":")
                if segments[0] != "TIME":
                    log.debug("Visited non-TIME leaf node {}".format(value))
                    idx += 1
                    continue

                # A bare "[TIME]" tag has no type segment: treat it as unrealizable.
                timestamp_type = segments[1] if len(segments) > 1 else None
                if timestamp_type == "month":
                    new_value = self._realize_month(child, previous_entity)
                elif timestamp_type == "year":
                    new_value = self._realize_year(child, previous_entity)
                else:
                    log.error(
                        "Visited TIME leaf node {} but couldn't realize it!".
                        format(value))
                    # Step past this slot; otherwise the loop revisits it forever.
                    idx += 1
                    continue

                if isinstance(new_value, list):
                    new_value = random.choice(new_value)

                new_components = []
                for component_idx, realization_token in enumerate(
                        new_value.split()):
                    new_slot = child.copy(include_fact=True)

                    # copy() also copies the attributes. If attach_attributes
                    # names this timestamp type, clear the attributes on every
                    # token slot whose position is not explicitly listed.
                    if (self.attach_attributes
                            and timestamp_type in self.attach_attributes
                            and component_idx
                            not in self.attach_attributes[timestamp_type]):
                        new_slot.attributes = {}

                    # Bind the current token via a default argument; a plain
                    # closure would see only the loop variable's final value.
                    new_slot.value = (
                        lambda f, realization_token=realization_token:
                        realization_token)
                    new_components.append(new_slot)

                this.children[idx:idx + 1] = new_components
                idx += len(new_components)
                log.debug(
                    "Visited TIME leaf node {} and realized it as {}".format(
                        value, new_value))
                previous_entity = value
            elif isinstance(child, DocumentPlanNode):
                log.debug("Visiting non-leaf '{}'".format(child))
                previous_entity = self._recurse(registry, random, language,
                                                child, previous_entity)
                idx += 1
            else:
                # Neither DocumentPlan nor Slot, must be f.ex. Literal -> skip.
                idx += 1
        return previous_entity