Example #1
    def _combine(self, registry: Registry, language: str, first: Message,
                 second: Message) -> Message:
        log.debug("Combining two templates:")
        log.debug("\t{}".format([c.value for c in first.template.components]))
        log.debug("\t{}".format([c.value for c in second.template.components]))

        shared_prefix = self._get_combinable_prefix(first, second)
        log.debug(f"Shared prefix is {[e.value for e in shared_prefix]}")
        combined = [c for c in first.template.components]

        # TODO At the moment everything is considered either positive or negative, which is sometimes weird.
        #  Add neutral sentences.
        conjunctions = registry.get("conjunctions").get(language, None)
        if not conjunctions:
            conjunctions = defaultdict(lambda: "NO-CONJUNCTION-DICT")

        if first.polarity != second.polarity:
            combined.append(
                Literal(
                    conjunctions.get("inverse_combiner",
                                     "MISSING-INVERSE-CONJUNCTION")))
        else:
            combined.append(
                Literal(
                    conjunctions.get("default_combiner",
                                     "MISSING-DEFAULT-CONJUNCTION")))
        combined.extend(second.template.components[len(shared_prefix):])
        log.debug("Combined thing is {}".format([c.value for c in combined]))
        new_message = Message(
            facts=first.facts +
            [fact for fact in second.facts if fact not in first.facts],
            importance_coefficient=first.importance_coefficient,
        )
        new_message.template = Template(combined)
        new_message.prevent_aggregation = True
        return new_message
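The combination step above boils down to: keep the first template, drop the shared prefix from the second, and join the remainders with a conjunction chosen by polarity. Below is a minimal standalone sketch of that idea using plain token lists instead of the project's Message/Template/Literal classes; all names in it are illustrative, not part of the original code.

from collections import defaultdict

def combine_templates(first, second, same_polarity, conjunctions=None):
    """Join two token lists, dropping their shared prefix and inserting a conjunction."""
    if conjunctions is None:
        conjunctions = defaultdict(lambda: "and")
    # Length of the prefix shared by the two token lists.
    prefix_len = 0
    for a, b in zip(first, second):
        if a != b:
            break
        prefix_len += 1
    key = "default_combiner" if same_polarity else "inverse_combiner"
    return first + [conjunctions[key]] + second[prefix_len:]

# Example: the shared prefix "the value" is not repeated in the output.
print(combine_templates(
    ["the", "value", "rose", "in", "2019"],
    ["the", "value", "fell", "in", "2020"],
    same_polarity=False,
    conjunctions={"default_combiner": "and", "inverse_combiner": "but"},
))
# -> ['the', 'value', 'rose', 'in', '2019', 'but', 'fell', 'in', '2020']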
Example #2
    def run(
        self,
        registry: Registry,
        random: RandomState,
        language: str,
        location_query: str,
        location_type_query: str,
        dataset: str,
        previous_location: str,
        ignored_cols: Optional[List[str]] = None,
    ) -> Tuple[List[Message], List[Message], List[Message]]:
        log.info(
            "Generating messages with location={}, location_type={}, data={}, previous_location={}"
            .format(location_query, location_type_query, dataset,
                    previous_location))

        data_store: DataFrameStore = registry.get("{}-data".format(dataset))
        log.debug("Underlying DataFrame is of size {}".format(
            data_store.all().shape))

        if ignored_cols is None:
            ignored_cols = []

        if location_query == "all":
            core_df = data_store.all()
            expanded_df = None
        elif self.expand:
            log.debug('Query: "{}"'.format(
                "location=={!r}".format(location_query)))
            core_df = data_store.query("location=={!r}".format(location_query))
            expanded_df = data_store.query(
                "location!={!r}".format(location_query))
        else:
            log.debug('Query: "{}"'.format(
                "location=={!r}".format(location_query)))
            core_df = data_store.query("location=={!r}".format(location_query))
            expanded_df = None
        log.debug("Resulting DataFrames are of sizes {} and {}".format(
            core_df.shape,
            "empty" if expanded_df is None else expanded_df.shape))

        core_messages: List[Message] = []
        expanded_messages: List[Message] = []
        col_names = [
            col_name for col_name in core_df.columns
            if not (
                col_name in [
                    "location", "location_type", "timestamp", "timestamp_type",
                    "agent", "agent_type"
                ]
                or col_name in ignored_cols
                or ":outlierness" in col_name
            )
        ]
        core_df.apply(self._gen_messages,
                      axis=1,
                      args=(col_names, core_messages))
        if expanded_df is not None:
            expanded_df.apply(self._gen_messages,
                              axis=1,
                              args=(col_names, expanded_messages))

        if log.getEffectiveLevel() <= 5:
            for m in core_messages:
                log.debug("Extracted CORE message {}".format(m.main_fact))
            for m in expanded_messages:
                log.debug("Extracted EXPANDED message {}".format(m.main_fact))

        log.info(
            "Extracted total {} core messages and {} expanded messages".format(
                len(core_messages), len(expanded_messages)))
        if not core_messages:
            raise NoMessagesForSelectionException("No core messages")

        # Remove all but 10k most interesting expanded messages
        expanded_messages = sorted(expanded_messages,
                                   key=lambda msg: msg.score,
                                   reverse=True)[:10_000]
        log.info(f"Filtered expanded messages to top {len(expanded_messages)}")

        if previous_location:
            log.info("Have previous_location, generating stuff for that")
            previous_location_messages = self._gen_messages_for_previous_location(
                registry, language, location_query, dataset, previous_location)
            log.info("Finished generating previous location related things")
        else:
            previous_location_messages = []

        return core_messages, expanded_messages, previous_location_messages
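The core of the message extraction above is a pandas pattern: filter out metadata and outlierness columns, then use DataFrame.apply with axis=1 to build messages row by row, accumulating them into a list passed via args. Here is a self-contained sketch of that pattern; the tiny frame and the dict "messages" are placeholders, not the project's Message class.

import pandas as pd

META_COLS = {"location", "location_type", "timestamp", "timestamp_type", "agent", "agent_type"}

def gen_messages(row, col_names, messages):
    # Called once per row; appends one message-like dict per value column.
    for col in col_names:
        messages.append({"location": row["location"], "column": col, "value": row[col]})

df = pd.DataFrame({
    "location": ["FI", "SE"],
    "location_type": ["country", "country"],
    "cphi": [1.2, 0.9],
    "cphi:outlierness": [0.1, 0.4],
})

col_names = [c for c in df.columns if c not in META_COLS and ":outlierness" not in c]
messages = []
df.apply(gen_messages, axis=1, args=(col_names, messages))
print(messages)
# -> one dict per (row, value column) pair; here only the "cphi" column qualifies.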
Example #3
class EUNlgService:
    def __init__(
        self,
        random_seed: Optional[int] = None,
        force_cache_refresh: bool = False,
        nomorphi: bool = False,
        planner: str = "full",
    ) -> None:
        """
        :param random_seed: seed for random number generation, for repeatability
        :param force_cache_refresh: if True, delete any existing cache files and recompute them
        :param nomorphi: don't load Omorphi for morphological generation. This removes the dependency on Omorphi,
            so it allows an easier setup, but means that no morphological inflection will be performed on the output,
            which is generally a very bad thing for the full pipeline
        """

        # New registry and result importer
        self.registry = Registry()

        # DataSets
        DATA_ROOT = Path(__file__).parent.absolute() / ".." / "data"

        self.datasets = [
            "cphi",
            "health_cost",
            "health_funding",
        ]
        for dataset in self.datasets:
            cache_path: Path = (DATA_ROOT / "{}.cache".format(dataset)).absolute()
            if not cache_path.exists():
                raise IOError("No cached dataset found at {}. Datasets must be generated before startup.")
            self.registry.register("{}-data".format(dataset), DataFrameStore(str(cache_path)))

        # Resources
        self.resources = [
            CPHIEnglishResource(),
            CPHIFinnishResource(),
            CPHICroatianResource(),
            CPHIRussianResource(),
            CPHIEstonianResource(),
            CPHISlovenianResource(),
            ENVEnglishResource(),
            ENVFinnishResource(),
            HealthCostEnglishResource(),
            HealthCostFinnishResource(),
            HealthFundingEnglishResource(),
            HealthFundingFinnishResource(),
        ]

        # Templates
        self.registry.register("templates", self._load_templates())

        # Slot Realizers:
        realizers: List[SlotRealizerComponent] = []
        for resource in self.resources:
            for realizer in resource.slot_realizer_components():
                realizers.append(realizer(self.registry))
        self.registry.register("slot-realizers", realizers)

        # Language metadata
        self.registry.register("conjunctions", CONJUNCTIONS)
        self.registry.register("errors", ERRORS)

        # PRNG seed
        self._set_seed(seed_val=random_seed)

        def _get_components(headline=False, planner="full"):
            # Put together the list of pipeline components.
            # This varies depending on whether we are generating a headline and which planner is in use.
            yield EUMessageGenerator(expand=True)
            yield EUImportanceSelector()
            if planner == "random":
                yield EURandomHeadlineDocumentPlanner() if headline else EURandomBodyDocumentPlanner()
            elif planner == "score":
                yield EUScoreHeadlineDocumentPlanner() if headline else EUScoreBodyDocumentPlanner()
            elif planner == "earlystop":
                yield EUEarlyStopHeadlineDocumentPlanner() if headline else EUEarlyStopBodyDocumentPlanner()
            elif planner == "topicsim":
                yield EUTopicSimHeadlineDocumentPlanner() if headline else EUTopicSimBodyDocumentPlanner()
            elif planner == "contextsim":
                yield EUContextSimHeadlineDocumentPlanner() if headline else EUContextSimBodyDocumentPlanner()
            elif planner == "neuralsim":
                if headline:
                    yield EUHeadlineDocumentPlanner()
                else:
                    yield TemplateAttacher()
                    yield EUNeuralSimBodyDocumentPlanner()
                    yield EmbeddingRemover()

            elif planner == "full":
                yield EUHeadlineDocumentPlanner() if headline else EUBodyDocumentPlanner()
            else:
                raise ValueError("INCORRECT PLANNER SETTING")
            yield TemplateSelector()
            yield Aggregator()
            yield SlotRealizer()
            yield LanguageSplitComponent(
                {
                    "en": EnglishEUDateRealizer(),
                    "fi": FinnishEUDateRealizer(),
                    "hr": CroatianEUDateRealizer(),
                    "de": GermanEUDateRealizer(),
                    "ru": RussianEUDateRealizer(),
                    "ee": EstonianEUDateRealizer(),
                    "sl": SlovenianEUDateRealizer(),
                }
            )
            yield EUEntityNameResolver()
            yield EUNumberRealizer()
            yield MorphologicalRealizer(
                {
                    "en": EnglishUralicNLPMorphologicalRealizer(),
                    "fi": FinnishUralicNLPMorphologicalRealizer(),
                    "hr": CroatianSimpleMorphologicalRealizer(),
                    "ru": RussianMorphologicalRealizer(),
                    "ee": EstonianUralicNLPMorphologicalRealizer(),
                    "sl": SlovenianSimpleMorphologicalRealizer(),
                }
            )
            yield HeadlineHTMLSurfaceRealizer() if headline else BodyHTMLSurfaceRealizer()

        log.info("Configuring Body NLG Pipeline (planner = {})".format(planner))
        self.body_pipeline = NLGPipeline(self.registry, *_get_components(planner=planner))
        self.headline_pipeline = NLGPipeline(self.registry, *_get_components(headline=True, planner=planner))

    T = TypeVar("T")

    def _get_cached_or_compute(
        self, cache: str, compute: Callable[..., T], force_cache_refresh: bool = False, relative_path: bool = True
    ) -> T:  # noqa: F821 -- Needed until https://github.com/PyCQA/pyflakes/issues/427 reaches a release
        if relative_path:
            cache = os.path.abspath(os.path.join(os.path.dirname(__file__), cache))
        if force_cache_refresh:
            log.info("force_cache_refresh is True, deleting previous cache from {}".format(cache))
            if os.path.exists(cache):
                os.remove(cache)
        if not os.path.exists(cache):
            log.info("No cache at {}, computing".format(cache))
            result = compute()
            with gzip.open(cache, "wb") as f:
                pickle.dump(result, f)
            return result
        else:
            log.info("Found cache at {}, decompressing and loading".format(cache))
            with gzip.open(cache, "rb") as f:
                return pickle.load(f)

    def _load_templates(self) -> Dict[str, List[Template]]:
        log.info("Loading templates")
        templates: Dict[str, List[Template]] = defaultdict(list)
        for resource in self.resources:
            for language, new_templates in read_templates(resource.templates)[0].items():
                templates[language].extend(new_templates)

        log.debug("Templates:")
        for lang, lang_templates in templates.items():
            log.debug("\t{}".format(lang))
            for templ in lang_templates:
                log.debug("\t\t{}".format(templ))
        return templates

    def get_locations(self, dataset: str) -> List[str]:
        return list(self.registry.get("{}-data".format(dataset)).all()["location"].unique()) + ["all"]

    def get_datasets(self, language: Optional[str] = None) -> List[str]:
        return list(
            {
                dataset
                for resource in self.resources
                for dataset in resource.supported_data
                if (language is None or resource.supports(language, dataset)) and dataset in self.datasets
            }
        )

    def get_languages(self):
        return list({language for resource in self.resources for language in resource.supported_languages})

    def run_pipeline(
        self, language: str, dataset: str, location: str, location_type: str, previous_location: Optional[str]
    ) -> Tuple[str, str]:
        log.info("Running headline NLG pipeline")
        try:
            headline_lang = "{}-head".format(language)
            headline = self.headline_pipeline.run((location, location_type, dataset, previous_location), headline_lang)
            log.info("Headline pipeline complete")
        except Exception as ex:
            headline = location
            log.error("%s", ex)

        # TODO: Figure out what DATA is supposed to be here?!
        log.info(
            "Running Body NLG pipeline: "
            "language={}, dataset={}, location={}, location_type={}, previous_location={}".format(
                language, dataset, location, location_type, previous_location
            )
        )
        try:
            body = self.body_pipeline.run((location, location_type, dataset, previous_location), language)
            log.info("Body pipeline complete")
        except NoMessagesForSelectionException:
            log.error("User selection returned no messages")
            body = ERRORS.get(language, {}).get(
                "no-messages-for-selection", "Something went wrong. Please try again later",
            )
        except Exception as ex:
            log.error("%s", ex)
            body = ERRORS.get(language, {}).get("general-error", "Something went wrong. Please try again later")

        return headline, body

    def _set_seed(self, seed_val: Optional[int] = None) -> None:
        log.info("Selecting seed for NLG pipeline")
        if seed_val is None:
            seed_val = randint(1, 10000000)
            log.info("No preset seed, using random seed {}".format(seed_val))
        else:
            log.info("Using preset seed {}".format(seed_val))
        self.registry.register("seed", seed_val)