Code Example #1
def _fit_model_with_scores(topic_model: TopicModel,
                           dataset: Dataset,
                           scores: List[BaseScore] = None,
                           num_fit_iterations: int = 1):

    # Attach the requested scores to the model, so that their values
    # are computed and recorded during training
    if scores is not None:
        for score in scores:
            score._attach(topic_model)

    # Fit the model on the dataset's batch vectorizer
    topic_model._fit(dataset.get_batch_vectorizer(),
                     num_iterations=num_fit_iterations)
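
For reference, a minimal usage sketch of _fit_model_with_scores (not part of the original listing): the import paths, the 'PLSA' family, the '@word' modality, and the data path are assumptions, while the constructor and function calls mirror those used elsewhere in this example.

# Hypothetical usage sketch; import paths and concrete values below are
# assumptions, not taken from the original listing
from topicnet.cooking_machine.dataset import Dataset
from topicnet.cooking_machine.models import TopicModel
from topnum.model_constructor import init_model_from_family  # assumed import path

dataset = Dataset(data_path='data/sample_vw.csv')  # any Vowpal Wabbit formatted dataset

artm_model = init_model_from_family(
    family='PLSA',                # assumed family name
    dataset=dataset,
    modalities_to_use=['@word'],  # assumed modality name
    main_modality='@word',
    num_topics=20,
    seed=0,
    num_processors=1,
    model_params=None,            # no extra model parameters
)
topic_model = TopicModel(artm_model)

# scores may be any list of BaseScore subclasses; each is attached before fitting
_fit_model_with_scores(topic_model, dataset, scores=None, num_fit_iterations=10)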

    # Method of the model-training class (the class definition is not shown in
    # this excerpt); trains one model per (number of topics, dataset subsample) pair
    def _train_models(
            self,
            text_collection: VowpalWabbitTextCollection,
            min_df_rate: float,
            max_df_rate: float,
            num_topics_to_skip: List[int] = None) -> None:

        modalities_to_use = list(text_collection._modalities.keys())
        main_modality = text_collection._main_modality

        # Candidate numbers of topics: an inclusive range with a fixed step
        numbers_of_topics = list(range(
            self._min_num_topics,
            self._max_num_topics + 1,
            self._num_topics_interval))

        if num_topics_to_skip is not None:
            numbers_of_topics = [
                n for n in numbers_of_topics if n not in num_topics_to_skip
            ]

        num_topics_for_message = ', '.join(
            [str(n) for n in numbers_of_topics[:10]]
        )

        print(f'\nFolder for saving models: "{self._models_folder_path}"')
        print(
            f'Training models for {len(numbers_of_topics)}'
            f' numbers of topics: {num_topics_for_message}...'
        )

        # Train models for each candidate number of topics
        for num_topics in tqdm(
                numbers_of_topics,
                total=len(numbers_of_topics),
                file=sys.stdout):

            # Each number of topics gets its own subfolder for saved models
            os.makedirs(
                self._folder_path_num_topics(num_topics)
            )

            subsample_data_paths = self._get_dataset_subsample_file_paths()

            # Train one model on each dataset subsample for the current number of topics
            for subsample_number, data_path in tqdm(
                    enumerate(subsample_data_paths),
                    total=len(subsample_data_paths),
                    file=sys.stdout):

                dataset = Dataset(data_path=data_path)

                # Drop tokens that are too rare or too frequent (by document frequency rate)
                dictionary = dataset.get_dictionary()
                dictionary.filter(
                    min_df_rate=min_df_rate,
                    max_df_rate=max_df_rate,
                )

                # Initialize an ARTM model of the requested family
                # and wrap it into a TopicModel
                artm_model = init_model_from_family(
                    family=self._model_family,
                    dataset=dataset,
                    modalities_to_use=modalities_to_use,
                    main_modality=main_modality,
                    num_topics=num_topics,
                    seed=self._model_seed,
                    num_processors=self._model_num_processors,
                    model_params=self._model_params,
                )
                topic_model = TopicModel(artm_model)

                # Fit the model for the configured number of collection passes
                topic_model._fit(
                    dataset_trainable=dataset.get_batch_vectorizer(),
                    num_iterations=self._num_fit_iterations,
                )

                # Save only Phi (the topic-word matrix); Theta is not stored
                model_save_path = self._folder_path_model(num_topics, subsample_number)
                topic_model.save(
                    model_save_path=model_save_path,
                    phi=True,
                    theta=False,
                )
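
The training loop relies on a few helper methods that are not shown in this listing (_folder_path_num_topics, _folder_path_model, _get_dataset_subsample_file_paths). Below is a possible sketch of them, purely as an assumption for illustration: the directory names and the _dataset_subsamples_folder_path attribute are hypothetical.

# Hypothetical helper sketches (would live on the same class as _train_models);
# the directory layout and _dataset_subsamples_folder_path are assumptions
import os
from glob import glob
from typing import List

def _folder_path_num_topics(self, num_topics: int) -> str:
    # e.g. <models_folder>/num_topics_20
    return os.path.join(self._models_folder_path, f'num_topics_{num_topics}')

def _folder_path_model(self, num_topics: int, subsample_number: int) -> str:
    # e.g. <models_folder>/num_topics_20/subsample_3
    return os.path.join(
        self._folder_path_num_topics(num_topics),
        f'subsample_{subsample_number}',
    )

def _get_dataset_subsample_file_paths(self) -> List[str]:
    # assumes the dataset subsamples were written out beforehand, one file each
    return sorted(glob(os.path.join(self._dataset_subsamples_folder_path, '*.csv')))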