def __init__(
        self,
        language: str,
        arguments_service: PostOCRArgumentsService,
        file_service: FileService,
        tokenize_service: BaseTokenizeService,
        run_type: RunType,
        **kwargs):
    super(NewsEyeDataset, self).__init__()

    self._arguments_service = arguments_service

    output_data_path = file_service.get_data_path()
    language_data_path = os.path.join(
        output_data_path, f'{run_type.to_str()}_language_data.pickle')

    # Train a SentencePiece model on the full corpus if no tokenizer model
    # has been loaded yet, then load the freshly trained model
    if not tokenize_service.is_tokenizer_loaded():
        full_data_path = os.path.join(
            'data', 'ICDAR2019_POCR_competition_dataset',
            'ICDAR2019_POCR_competition_full_22M_without_Finnish')

        vocabulary_size = tokenize_service.vocabulary_size
        train_spm_model(full_data_path, output_data_path,
                        language, vocabulary_size)
        tokenize_service.load_tokenizer_model()

    # Preprocess the raw competition data once; later runs reuse the pickle
    if not os.path.exists(language_data_path):
        train_data_path = os.path.join(
            'data', 'ICDAR2019_POCR_competition_dataset',
            'ICDAR2019_POCR_competition_training_18M_without_Finnish')
        test_data_path = os.path.join(
            'data', 'ICDAR2019_POCR_competition_dataset',
            'ICDAR2019_POCR_competition_evaluation_4M_without_Finnish')

        preprocess_data(language, train_data_path, test_data_path,
                        output_data_path, tokenize_service.tokenizer)

    with open(language_data_path, 'rb') as data_file:
        self._language_data: LanguageData = pickle.load(data_file)
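# The train_spm_model helper called above is not shown in this section. Below
# is a minimal sketch of what it might do, assuming the `sentencepiece`
# library; the file-listing logic and the model_prefix naming are assumptions
# for illustration, not the repository's actual implementation.
import os
import sentencepiece as spm

def train_spm_model(data_path: str, output_path: str,
                    language: str, vocabulary_size: int):
    # Gather the raw text files that make up the training corpus
    input_files = [
        os.path.join(data_path, file_name)
        for file_name in os.listdir(data_path)]

    # Train a SentencePiece model, writing `<language>-spm.model` and
    # `<language>-spm.vocab` next to the other output artifacts
    spm.SentencePieceTrainer.train(
        input=input_files,
        model_prefix=os.path.join(output_path, f'{language}-spm'),
        vocab_size=vocabulary_size)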
def _get_language_data_path(
        self,
        file_service: FileService,
        run_type: RunType):
    output_data_path = file_service.get_data_path()
    language_data_path = os.path.join(
        output_data_path, f'{run_type.to_str()}_language_data.pickle')

    if not os.path.exists(language_data_path):
        challenge_path = file_service.get_challenge_path()
        full_data_path = os.path.join(challenge_path, 'full')
        if not os.path.exists(full_data_path) or len(os.listdir(full_data_path)) == 0:
            newseye_path = os.path.join('data', 'newseye')
            trove_path = os.path.join('data', 'trove')
            # TODO Fix download
            # ocr_download.combine_data(challenge_path, newseye_path, trove_path)
            # Fail fast until the download step is fixed, since preprocessing
            # below cannot succeed without the raw data
            raise FileNotFoundError(
                f'Raw challenge data is missing from "{full_data_path}" and '
                'automatic download is currently broken; place the NewsEye '
                f'and Trove data under "{newseye_path}" and "{trove_path}" '
                'manually')

        pickles_path = file_service.get_pickles_path()
        preprocess_data(
            self._tokenize_service,
            self._metrics_service,
            self._vocabulary_service,
            pickles_path,
            full_data_path,
            output_data_path)

    return language_data_path
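# RunType.to_str() determines the pickle file name in the helpers above, but
# its definition is not part of this section. The sketch below is a
# hypothetical stand-in that assumes to_str() returns the lowercase split
# name used in `{run_type.to_str()}_language_data.pickle`.
from enum import Enum

class RunType(Enum):
    Train = 'train'
    Validation = 'validation'
    Test = 'test'

    def to_str(self) -> str:
        # The enum value doubles as the string used in file names
        return self.value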
def _get_language_data_path(
        self,
        file_service: FileService,
        run_type: RunType):
    output_data_path = file_service.get_data_path()
    language_data_path = os.path.join(
        output_data_path, f'{run_type.to_str()}_language_data.pickle')

    if not os.path.exists(language_data_path):
        train_data_path = file_service.get_pickles_path()
        test_data_path = None
        preprocess_data(train_data_path, test_data_path, output_data_path,
                        self._tokenize_service.tokenizer,
                        self._vocabulary_service)

    return language_data_path
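# Hypothetical call site for either _get_language_data_path variant above,
# mirroring the loading pattern of the first constructor in this section:
# resolve (and, if needed, build) the pickle, then deserialize it. The
# constructor signature here is an assumption for illustration.
def __init__(self, file_service: FileService, run_type: RunType, **kwargs):
    super().__init__()

    language_data_path = self._get_language_data_path(file_service, run_type)
    with open(language_data_path, 'rb') as data_file:
        self._language_data: LanguageData = pickle.load(data_file)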
def __init__(
        self,
        arguments_service: NERArgumentsService,
        vocabulary_service: VocabularyService,
        file_service: FileService,
        tokenize_service: BaseTokenizeService,
        data_service: DataService,
        cache_service: CacheService,
        string_process_service: StringProcessService):
    super().__init__()

    self._arguments_service = arguments_service
    self._tokenize_service = tokenize_service
    self._file_service = file_service
    self._data_service = data_service
    self._string_process_service = string_process_service

    self._entity_tag_types = arguments_service.entity_tag_types
    self._data_version = "1.3"

    # Special tokens and their fixed indices in the vocabulary
    self.PAD_TOKEN = '[PAD]'
    self.START_TOKEN = '[CLS]'
    self.STOP_TOKEN = '[SEP]'
    self.pad_idx = 0
    self.start_idx = 1
    self.stop_idx = 2

    data_path = file_service.get_data_path()
    language_suffix = self.get_language_suffix(arguments_service.language)

    # Cache keys encode every preprocessing option so that a stale cache is
    # never reused after a configuration change
    train_cache_key = (
        f'train-hipe-data-v{self._data_version}'
        f'-limit-{arguments_service.train_dataset_limit_size}'
        f'-{arguments_service.split_type.value}'
        f'-merge-{arguments_service.merge_subwords}'
        f'-replacen-{arguments_service.replace_all_numbers}')
    validation_cache_key = (
        f'validation-hipe-data-v{self._data_version}'
        f'-limit-{arguments_service.validation_dataset_limit_size}'
        f'-{arguments_service.split_type.value}'
        f'-merge-{arguments_service.merge_subwords}'
        f'-replacen-{arguments_service.replace_all_numbers}')

    self._train_ne_collection = cache_service.get_item_from_cache(
        item_key=train_cache_key,
        callback_function=lambda: self.preprocess_data(
            os.path.join(
                data_path,
                f'HIPE-data-v{self._data_version}-train-{language_suffix}.tsv'),
            limit=arguments_service.train_dataset_limit_size))

    self._validation_ne_collection = cache_service.get_item_from_cache(
        item_key=validation_cache_key,
        callback_function=lambda: self.preprocess_data(
            os.path.join(
                data_path,
                f'HIPE-data-v{self._data_version}-dev-{language_suffix}.tsv'),
            limit=arguments_service.validation_dataset_limit_size))

    # The test split is only preprocessed when running in evaluation mode
    if arguments_service.evaluate:
        test_cache_key = (
            f'test-hipe-data-v{self._data_version}'
            f'-{arguments_service.split_type.value}'
            f'-merge-{arguments_service.merge_subwords}'
            f'-replacen-{arguments_service.replace_all_numbers}')
        self._test_ne_collection = cache_service.get_item_from_cache(
            item_key=test_cache_key,
            callback_function=lambda: self.preprocess_data(
                os.path.join(
                    data_path,
                    f'HIPE-data-v{self._data_version}-test-{language_suffix}.tsv')))

    self._entity_mappings = self._create_entity_mappings(
        self._train_ne_collection,
        self._validation_ne_collection)

    vocabulary_cache_key = f'char-vocabulary-{self._data_version}'
    vocabulary_data = cache_service.get_item_from_cache(
        item_key=vocabulary_cache_key,
        callback_function=lambda: self._generate_vocabulary_data(
            language_suffix, self._data_version))

    vocabulary_service.initialize_vocabulary_data(vocabulary_data)
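# The constructor above leans on CacheService.get_item_from_cache: on a hit
# it returns the previously pickled result, on a miss it runs the callback
# and persists the value. A minimal sketch of that pattern follows, assuming
# a pickle-per-key layout; the real CacheService implementation is not shown
# here, and the constructor argument and attribute names are assumptions.
import os
import pickle
from typing import Callable

class CacheService:
    def __init__(self, cache_path: str):
        self._cache_path = cache_path

    def get_item_from_cache(self, item_key: str, callback_function: Callable):
        item_path = os.path.join(self._cache_path, f'{item_key}.pickle')

        if os.path.exists(item_path):
            # Cache hit: reuse the previously pickled result
            with open(item_path, 'rb') as cache_file:
                return pickle.load(cache_file)

        # Cache miss: compute the value once and persist it for later runs
        item = callback_function()
        with open(item_path, 'wb') as cache_file:
            pickle.dump(item, cache_file)

        return item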