def _fit_model_with_scores(
        topic_model: TopicModel,
        dataset: Dataset,
        scores: List[BaseScore] = None,
        num_fit_iterations: int = 1):
    """Attach the given scores to the model, then fit it on the dataset.

    Parameters
    ----------
    topic_model
        Model to fit. Each score is handed this model through
        ``score._attach(topic_model)`` before fitting starts.
    dataset
        Provides the batch vectorizer that is passed to
        ``topic_model._fit``.
    scores
        Scores to attach before fitting. ``None`` means nothing is attached.
    num_fit_iterations
        Forwarded to ``topic_model._fit`` as ``num_iterations``.
    """
    scores_to_attach = [] if scores is None else scores

    # Scores are attached first so that they are in place when fitting runs.
    for current_score in scores_to_attach:
        current_score._attach(topic_model)

    batch_vectorizer = dataset.get_batch_vectorizer()
    topic_model._fit(batch_vectorizer, num_iterations=num_fit_iterations)
def _train_models(
        self,
        text_collection: VowpalWabbitTextCollection,
        min_df_rate: float,
        max_df_rate: float,
        num_topics_to_skip: List[int] = None) -> None:
    """Train and save one model per (number of topics, dataset subsample).

    The numbers of topics run from ``self._min_num_topics`` up to and
    including ``self._max_num_topics`` in steps of
    ``self._num_topics_interval``; values listed in ``num_topics_to_skip``
    are left out. For every remaining number of topics a folder is created,
    and for every subsample file a model is initialised, fitted for
    ``self._num_fit_iterations`` iterations and saved.

    Parameters
    ----------
    text_collection
        Supplies the modalities (``_modalities`` keys) and the main modality
        (``_main_modality``) used to initialise each model.
    min_df_rate
        Passed to the dataset dictionary's ``filter`` as ``min_df_rate``.
    max_df_rate
        Passed to the dataset dictionary's ``filter`` as ``max_df_rate``.
    num_topics_to_skip
        Numbers of topics to exclude from training. ``None`` excludes none.

    Raises
    ------
    FileExistsError
        From ``os.makedirs`` if the folder for some number of topics
        already exists.
    """
    modalities_to_use = [*text_collection._modalities.keys()]
    main_modality = text_collection._main_modality

    topic_count_grid = range(
        self._min_num_topics,
        self._max_num_topics + 1,
        self._num_topics_interval,
    )

    if num_topics_to_skip is None:
        topic_counts = list(topic_count_grid)
    else:
        topic_counts = [
            count for count in topic_count_grid
            if count not in num_topics_to_skip
        ]

    # Only the first ten values are shown to keep the message short.
    preview = ', '.join(map(str, topic_counts[:10]))
    total_counts = len(topic_counts)

    print(f'\n Folder for models saving: "{self._models_folder_path}"')
    print(
        f'Training models for {total_counts}'
        f' numbers of topics: {preview}...'
    )

    for num_topics in tqdm(topic_counts, total=total_counts, file=sys.stdout):
        num_topics_folder = self._folder_path_num_topics(num_topics)
        os.makedirs(num_topics_folder)

        subsample_paths = self._get_dataset_subsample_file_paths()
        subsample_progress = tqdm(
            enumerate(subsample_paths),
            total=len(subsample_paths),
            file=sys.stdout,
        )

        for subsample_index, subsample_path in subsample_progress:
            dataset = Dataset(data_path=subsample_path)

            # The return value is unused, so the filter presumably works in
            # place on the dataset's dictionary -- TODO confirm.
            dataset.get_dictionary().filter(
                min_df_rate=min_df_rate,
                max_df_rate=max_df_rate,
            )

            topic_model = TopicModel(
                init_model_from_family(
                    family=self._model_family,
                    dataset=dataset,
                    modalities_to_use=modalities_to_use,
                    main_modality=main_modality,
                    num_topics=num_topics,
                    seed=self._model_seed,
                    num_processors=self._model_num_processors,
                    model_params=self._model_params,
                )
            )
            topic_model._fit(
                dataset_trainable=dataset.get_batch_vectorizer(),
                num_iterations=self._num_fit_iterations,
            )

            destination = self._folder_path_model(num_topics, subsample_index)
            topic_model.save(
                model_save_path=destination,
                phi=True,
                theta=False,
            )