def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by
    at most one process on each node (current default behavior is for only
    rank 0 to read the checkpoint from disk), load_on_all_ranks should be
    False to avoid errors from torch.distributed not having been initialized
    or torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.

    Args:
        path: checkpoint location; resolved through
            ``PathManager.get_local_path``.
        arg_overrides: optional mapping of names to values. Each entry is
            set as an attribute on ``state["args"]`` (when present and not
            None) and passed to ``overwrite_args_by_name`` for
            ``state["cfg"]`` (when present and not None).
        load_on_all_ranks: when True, ``torch.distributed.barrier()`` is
            called after the stale local copy has been removed.

    Returns:
        The loaded state after being passed through ``_upgrade_state_dict``.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:
        # hack to be able to set Namespace in dict config. this should be
        # removed when we update to newer omegaconf version that supports
        # object flags, or when we migrate all existing models
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True
        try:
            state["cfg"] = OmegaConf.create(state["cfg"])
        finally:
            # Always undo the module-level monkeypatch, even if create()
            # raises; otherwise every later omegaconf call in this process
            # would see the permissive primitive check.
            _utils.is_primitive_type = old_primitive

        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
def get_indexed_dataset_to_local(path):
    """Resolve an indexed dataset's index and data files to local paths.

    Both files are fetched through ``PathManager.get_local_path`` (index
    first, then data).

    Args:
        path: dataset path prefix, passed to ``index_file_path`` and
            ``data_file_path``.

    Returns:
        The local data file path with its 4-character ".bin" suffix removed.

    Raises:
        AssertionError: if the local paths do not end in ".idx" / ".bin", or
            if the two paths do not share the same prefix.
    """
    index_local = PathManager.get_local_path(index_file_path(path))
    data_local = PathManager.get_local_path(data_file_path(path))

    has_expected_suffixes = index_local.endswith(".idx") and data_local.endswith(".bin")
    assert has_expected_suffixes, (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{index_local} and {data_local}"
    )

    prefix = data_local[:-4]  # strip the ".bin" suffix
    assert prefix == index_local[:-4]  # strip the ".idx" suffix
    return prefix
def load_langs(cls, args, **kwargs):
    """Determine the list of languages from the parsed options.

    Exactly one source is used:

    * neither ``args.lang_dict`` nor ``args.langs`` given (both ``None``):
      the languages are collected from ``args.lang_pairs`` (each pair split
      on "-") and sorted;
    * ``args.lang_dict`` given: one language per non-blank line of that file,
      in file order;
    * ``args.langs`` given: used as-is.

    Args:
        cls: not used by this function.
        args: object exposing ``lang_dict``, ``langs`` and ``lang_pairs``.
        **kwargs: not used by this function.

    Returns:
        The list of languages.

    Raises:
        ValueError: if both ``args.lang_dict`` and ``args.langs`` are set.
    """
    if args.lang_dict and args.langs:
        raise ValueError("--langs and --lang-dict can not both be specified")

    if args.lang_dict is None and args.langs is None:
        logger.warning(
            "External language dictionary is not provided; "
            "use lang-pairs to infer the set of supported languages. "
            "The language ordering is not stable which might cause "
            "misalignment in pretraining and finetuning."
        )
        # Infer from lang_pairs: gather both sides of every pair, then sort.
        seen = set()
        for lang_pair in args.lang_pairs:
            seen.update(lang_pair.split("-"))
        langs = sorted(seen)
        logger.info(f"inferred language list: {langs}")
    elif args.lang_dict:
        dict_path = PathManager.get_local_path(args.lang_dict)
        with open(dict_path, "r", encoding="utf-8") as f:
            stripped_rows = (row.strip() for row in f.readlines())
            langs = [name for name in stripped_rows if name]
        logger.info(
            f"loaded language list from {args.lang_dict} as they are ordered in file"
        )
    elif args.langs:
        langs = args.langs
        logger.info(
            f"parsed the language list as they are ordered in the option: {langs}"
        )
    return langs
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
    """Count the symbols of a text file and add them to ``dict``.

    The file is split at the offsets returned by ``find_offsets``. With more
    than one worker each chunk is counted in a separate pool process;
    otherwise the single chunk is counted in-process. Each resulting counter
    is merged into ``dict`` in sorted key order via ``dict.add_symbol``.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        dict: dictionary object providing ``add_symbol`` and ``eos_word``.
        tokenize: tokenizer forwarded to the worker function.
        num_workers: number of chunks / pool processes.
    """

    def merge_result(counter):
        for symbol, count in sorted(counter.items()):
            dict.add_symbol(symbol, count)

    local_file = PathManager.get_local_path(filename)
    offsets = find_offsets(local_file, num_workers)
    count_chunk = Dictionary._add_file_to_dictionary_single_worker

    if not num_workers > 1:
        merge_result(
            count_chunk(local_file, tokenize, dict.eos_word, offsets[0], offsets[1])
        )
        return

    pool = Pool(processes=num_workers)
    pending = [
        pool.apply_async(
            count_chunk,
            (local_file, tokenize, dict.eos_word, start_offset, end_offset),
        )
        for start_offset, end_offset in zip(offsets, offsets[1:])
    ]
    pool.close()
    pool.join()
    for job in pending:
        merge_result(job.get())
def _binarize_file_chunk(
    binarizer: Binarizer,
    filename: str,
    offset_start: int,
    offset_end: int,
    output_prefix: str,
    dataset_impl: str,
    vocab_size=None,
) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
    """
    Create a dataset builder and append the binarized lines of one file chunk.

    The builder is returned without being finalized, which is useful if you
    want to do other things with the bin file, such as appending or merging
    other files.

    Args:
        binarizer: object whose ``binarize_line(line, summary)`` produces the
            item added to the builder.
        filename: input file; resolved through ``PathManager.get_local_path``.
        offset_start: chunk start, forwarded to ``Chunker``.
        offset_end: chunk end, forwarded to ``Chunker``.
        output_prefix: prefix from which the data file path is derived.
        dataset_impl: forwarded to ``indexed_dataset.make_builder`` as ``impl``.
        vocab_size: forwarded to ``indexed_dataset.make_builder``.

    Returns:
        ``(builder, summary)``.
    """
    builder = indexed_dataset.make_builder(
        indexed_dataset.data_file_path(output_prefix),
        impl=dataset_impl,
        vocab_size=vocab_size,
    )
    stats = BinarizeSummary()

    local_file = PathManager.get_local_path(filename)
    with Chunker(local_file, offset_start, offset_end) as line_iterator:
        for text in line_iterator:
            item = binarizer.binarize_line(text, stats)
            builder.add_item(item)

    return builder, stats
def _add_file_to_dictionary_single_worker(
    filename, tokenize, eos_word, worker_id=0, num_workers=1
):
    """Count the tokens in one worker's share of a text file.

    The file size is divided evenly by ``num_workers``; this worker starts at
    ``worker_id * chunk_size`` (skipping the first, possibly partial, line
    when not at the start of the file) and stops once it has read past the
    end of its chunk.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        tokenize: callable returning the words of a line.
        eos_word: symbol counted once per line read.
        worker_id: index of this worker's chunk.
        num_workers: total number of chunks.

    Returns:
        A ``Counter`` of words (plus ``eos_word``).
    """
    counts = Counter()
    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_workers
        offset = worker_id * chunk_size
        end = offset + chunk_size

        f.seek(offset)
        if offset > 0:
            safe_readline(f)  # drop first incomplete line

        while True:
            line = f.readline()
            if not line:
                break
            for word in tokenize(line):
                counts[word] += 1
            counts[eos_word] += 1
            # f.tell() returns only an opaque number which can
            # return to the position in the file via f.seek()
            # and does not necessarily represent a byte position
            # in the file. However, f.tell() is faithful to the
            # byte position _most of the time_. Thus we can just
            # check against the file size to prevent early exit.
            position = f.tell()
            if end < position < size:
                break
    return counts
def _add_file_to_dictionary_single_worker(
    filename, tokenize, eos_word, worker_id=0, num_workers=1, L=None
):
    """Count the tokens in one worker's share of a text file, truncating lines.

    Same chunking as the unmodified worker: the file size is divided evenly
    by ``num_workers`` and this worker reads from ``worker_id * chunk_size``
    until it has passed the end of its chunk. Each line is cut to
    ``line[:L]`` before being tokenized.

    NOTE(review): ``line[:L]`` slices the first ``L`` *characters* of the
    line, whereas the original comment described this as taking the first
    ``L`` *words* - confirm which is intended.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        tokenize: callable returning the words of a (truncated) line.
        eos_word: symbol counted once per line read.
        worker_id: index of this worker's chunk.
        num_workers: total number of chunks.
        L: slice bound applied to every line; ``None`` keeps the whole line.

    Returns:
        A ``Counter`` of words (plus ``eos_word``).
    """
    counts = Counter()
    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_workers
        offset = worker_id * chunk_size
        end = offset + chunk_size

        f.seek(offset)
        if offset > 0:
            safe_readline(f)  # drop first incomplete line

        while True:
            line = f.readline()
            if not line:
                break
            for word in tokenize(line[:L]):
                counts[word] += 1
            counts[eos_word] += 1
            if f.tell() > end:
                break
    return counts
def binarize(
    filename,
    dict,
    consumer,
    tokenize=tokenize_line,
    append_eos=True,
    reverse_order=False,
    offset=0,
    end=-1,
    already_numberized=False,
) -> Dict[str, int]:
    """Convert the lines of (a chunk of) a text file to id tensors.

    Reading starts at ``offset`` and stops at end of file, or once the
    reported file position has moved past ``end`` (only when ``end > 0``).
    Every id tensor is handed to ``consumer``.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        dict: dictionary providing ``encode_line``, ``eos``, ``unk_index``
            and ``unk_word``.
        consumer: callable receiving each id tensor.
        tokenize: line tokenizer forwarded to ``dict.encode_line``.
        append_eos: append the end-of-sentence id to every line.
        reverse_order: reverse the ids of every line.
        offset: position to seek to before reading.
        end: position after which reading stops; ignored unless positive.
        already_numberized: treat each line as whitespace-separated integer
            ids instead of encoding it with ``dict``.

    Returns:
        A dict with ``nseq`` (lines processed), ``nunk`` (total replaced
        words), ``ntok`` (total ids) and ``replaced`` (Counter of words that
        were mapped to unk).
    """
    replaced = Counter()

    def replaced_consumer(word, idx):
        if idx == dict.unk_index and word != dict.unk_word:
            replaced.update([word])

    def to_ids(text):
        if not already_numberized:
            return dict.encode_line(
                line=text,
                line_tokenizer=tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=append_eos,
                reverse_order=reverse_order,
            )
        id_list = [int(id_string) for id_string in text.strip().split()]
        if reverse_order:
            id_list.reverse()
        if append_eos:
            id_list.append(dict.eos())
        return torch.IntTensor(id_list)

    nseq = ntok = 0
    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        f.seek(offset)
        # next(f) breaks f.tell(), hence readline() must be used
        line = safe_readline(f)
        while line:
            # f.tell() does not always give the byte position in the file
            # sometimes it skips to a very large number
            # it is unlikely that through a normal read we go from
            # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely
            # that the procedure breaks by the undeterministic behavior of
            # f.tell()
            if end > 0:
                position = f.tell()
                if end < position < end + 2**32:
                    break
            ids = to_ids(line)
            nseq += 1
            ntok += len(ids)
            consumer(ids)
            line = f.readline()

    return {
        "nseq": nseq,
        "nunk": sum(replaced.values()),
        "ntok": ntok,
        "replaced": replaced,
    }
def binarize(
    filename,
    dict,
    consumer,
    tokenize=tokenize_line,
    append_eos=True,
    reverse_order=False,
    offset=0,
    end=-1,
    already_numberized=False,
):
    """Convert the lines of (a chunk of) a text file to id tensors.

    Reading starts at ``offset`` and stops at end of file, or once the
    reported file position exceeds ``end`` (only when ``end > 0``). Every id
    tensor is handed to ``consumer``.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        dict: dictionary providing ``encode_line``, ``eos``, ``unk_index``
            and ``unk_word``.
        consumer: callable receiving each id tensor.
        tokenize: line tokenizer forwarded to ``dict.encode_line``.
        append_eos: append the end-of-sentence id to every line.
        reverse_order: reverse the ids of every line.
        offset: position to seek to before reading.
        end: position after which reading stops; ignored unless positive.
        already_numberized: treat each line as whitespace-separated integer
            ids instead of encoding it with ``dict``.

    Returns:
        A dict with ``nseq`` (lines processed), ``nunk`` (total replaced
        words), ``ntok`` (total ids) and ``replaced`` (Counter of words that
        were mapped to unk).
    """
    replaced = Counter()

    def replaced_consumer(word, idx):
        if idx == dict.unk_index and word != dict.unk_word:
            replaced.update([word])

    def ids_from_numbers(text):
        id_list = [int(id_string) for id_string in text.strip().split()]
        if reverse_order:
            id_list.reverse()
        if append_eos:
            id_list.append(dict.eos())
        return torch.IntTensor(id_list)

    def ids_from_text(text):
        return dict.encode_line(
            line=text,
            line_tokenizer=tokenize,
            add_if_not_exist=False,
            consumer=replaced_consumer,
            append_eos=append_eos,
            reverse_order=reverse_order,
        )

    nseq = ntok = 0
    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        f.seek(offset)
        # next(f) breaks f.tell(), hence readline() must be used
        line = safe_readline(f)
        while line:
            past_end = end > 0 and f.tell() > end
            if past_end:
                break
            ids = ids_from_numbers(line) if already_numberized else ids_from_text(line)
            nseq += 1
            ntok += len(ids)
            consumer(ids)
            line = f.readline()

    return {
        "nseq": nseq,
        "nunk": sum(replaced.values()),
        "ntok": ntok,
        "replaced": replaced,
    }
def find_offsets(filename, num_chunks):
    """Compute line-aligned chunk boundaries for a text file.

    The file size is divided evenly by ``num_chunks``. For each interior
    boundary the file is positioned at the nominal split point, the rest of
    that line is consumed with ``safe_readline``, and the resulting
    ``f.tell()`` is recorded.

    Args:
        filename: file to inspect; resolved through
            ``PathManager.get_local_path``.
        num_chunks: number of chunks to split the file into.

    Returns:
        A list of ``num_chunks + 1`` positions. The first and the last entry
        are left at 0; only the interior entries are computed.
    """
    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_chunks
        boundaries = [0] * (num_chunks + 1)
        for index in range(1, num_chunks):
            f.seek(index * chunk_size)
            safe_readline(f)
            boundaries[index] = f.tell()
        return boundaries
def cached_path_from_pm(url_or_filename):
    """
    Try to obtain a local (cached) path for a URL or filename through the
    fairseq ``PathManager``.

    Returns the local path on success, or ``None`` if anything goes wrong,
    including ``fairseq.file_io`` not being importable.
    """
    try:
        from fairseq.file_io import PathManager

        return PathManager.get_local_path(url_or_filename)
    except Exception:
        # Deliberately best-effort: any failure is reported as "no cached path".
        return None
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: checkpoint location; resolved through
            ``PathManager.get_local_path``.
        arg_overrides: optional mapping; each item is set as an attribute on
            the checkpoint's ``state["args"]``.

    Returns:
        The loaded state after being passed through ``_upgrade_state_dict``.
    """

    def restore_on_cpu(storage, location):
        return default_restore_location(storage, "cpu")

    local_path = PathManager.get_local_path(path)
    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=restore_on_cpu)

    # Looked up unconditionally, exactly as before: a checkpoint without
    # "args" fails here even when no overrides are given.
    args = state["args"]
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    return _upgrade_state_dict(state)
def binarize_tag(filename, consumer, offset=0, end=-1) -> Dict[str, int]:
    """Convert lines of comma-separated integers to id tensors.

    Each value is parsed as an integer and incremented by one before being
    packed into a ``torch.IntTensor`` that is handed to ``consumer``.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        consumer: callable receiving each id tensor.
        offset: position to seek to before reading.
        end: position after which reading stops; ignored unless positive.

    Returns:
        ``{"nseq": <number of lines processed>}``.
    """
    processed = 0
    with open(PathManager.get_local_path(filename), "r") as f:
        f.seek(offset)
        line = safe_readline(f)
        while line:
            past_end = end > 0 and f.tell() > end
            if past_end:
                break
            shifted = [int(field.strip()) + 1 for field in line.split(',')]
            processed += 1
            consumer(torch.IntTensor(shifted))
            line = f.readline()
    return {"nseq": processed}
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1) -> Dict[str, int]:
    """Parse every line of a file chunk as an alignment and pass it on.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        alignment_parser: callable turning one line into the value given to
            ``consumer``.
        consumer: callable receiving each parsed alignment.
        offset: chunk start, forwarded to ``Chunker``.
        end: chunk end, forwarded to ``Chunker``.

    Returns:
        ``{"nseq": <number of lines processed>}``.
    """
    processed = 0
    local_file = PathManager.get_local_path(filename)
    with Chunker(local_file, offset, end) as line_iterator:
        for processed, text in enumerate(line_iterator, start=1):
            consumer(alignment_parser(text))
    return {"nseq": processed}
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: checkpoint location; resolved through
            ``PathManager.get_local_path``.
        arg_overrides: optional mapping of names to values. Each entry is set
            as an attribute on ``state["args"]`` and passed to
            ``overwrite_args_by_name`` for ``state["cfg"]``, for whichever of
            the two is present and not None.

    Returns:
        The loaded state after being passed through ``_upgrade_state_dict``.
    """
    local_path = PathManager.get_local_path(path)
    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if arg_overrides is not None:
        if "args" in state and state["args"] is not None:
            legacy_args = state["args"]
            for arg_name, arg_val in arg_overrides.items():
                setattr(legacy_args, arg_name, arg_val)
        if "cfg" in state and state["cfg"] is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    return _upgrade_state_dict(state)
def binarize(
    filename,
    dict,
    consumer,
    tokenize=tokenize_line,
    append_eos=True,
    reverse_order=False,
    offset=0,
    end=-1,
    already_numberized=False,
) -> Dict[str, int]:
    """Convert the lines of a file chunk to id tensors.

    Lines are read through ``Chunker`` between ``offset`` and ``end``; every
    id tensor is handed to ``consumer``.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        dict: dictionary providing ``encode_line``, ``eos``, ``unk_index``
            and ``unk_word``.
        consumer: callable receiving each id tensor.
        tokenize: line tokenizer forwarded to ``dict.encode_line``.
        append_eos: append the end-of-sentence id to every line.
        reverse_order: reverse the ids of every line.
        offset: chunk start, forwarded to ``Chunker``.
        end: chunk end, forwarded to ``Chunker``.
        already_numberized: treat each line as whitespace-separated integer
            ids instead of encoding it with ``dict``.

    Returns:
        A dict with ``nseq`` (lines processed), ``nunk`` (total replaced
        words), ``ntok`` (total ids) and ``replaced`` (Counter of words that
        were mapped to unk).
    """
    replaced = Counter()

    def replaced_consumer(word, idx):
        if idx == dict.unk_index and word != dict.unk_word:
            replaced.update([word])

    def to_ids(text):
        if not already_numberized:
            return dict.encode_line(
                line=text,
                line_tokenizer=tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=append_eos,
                reverse_order=reverse_order,
            )
        id_list = [int(id_string) for id_string in text.strip().split()]
        if reverse_order:
            id_list.reverse()
        if append_eos:
            id_list.append(dict.eos())
        return torch.IntTensor(id_list)

    nseq = ntok = 0
    local_file = PathManager.get_local_path(filename)
    with Chunker(local_file, offset, end) as line_iterator:
        for text in line_iterator:
            ids = to_ids(text)
            nseq += 1
            ntok += len(ids)
            consumer(ids)

    return {
        "nseq": nseq,
        "nunk": sum(replaced.values()),
        "ntok": ntok,
        "replaced": replaced,
    }
def add_from_file(self, f):
    """
    Loads a pre-existing dictionary from a text file and adds its symbols
    to this instance.

    Args:
        f: either a path (``str``), which is opened as UTF-8 through
            ``PathManager.get_local_path``, or an object with
            ``readlines()``. Each entry line has the form
            ``<token> <cnt> [flags]``; the only flag handled here is
            ``#fairseq:overwrite``.

    Raises:
        FileNotFoundError: if the path cannot be opened.
        Exception: if the file is not valid UTF-8.
        RuntimeError: if a word is already present and its line lacks the
            ``#fairseq:overwrite`` flag.
        ValueError: if a line does not match the expected format.
    """
    if isinstance(f, str):
        # HACK(review): machine-specific redirect of one hard-coded path to
        # a local copy. Kept so the existing setup keeps working, but it
        # should be replaced by configuration and removed from library code.
        if f == '/home/ubuntu/project/manifest/train/dict.ltr.txt':
            f = '/home/ras306/Classwork/Project/Wav2Vec/model/dict.ltr.txt'
        try:
            with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
                self.add_from_file(fd)
        except FileNotFoundError as fnfe:
            raise fnfe
        except UnicodeError:
            raise Exception("Incorrect encoding detected in {}, please "
                            "rebuild the dataset".format(f))
        return

    lines = f.readlines()
    # _load_meta returns the index of the first line holding an entry.
    indices_start_line = self._load_meta(lines)

    for line in lines[indices_start_line:]:
        try:
            line, field = line.rstrip().rsplit(" ", 1)
            if field == "#fairseq:overwrite":
                # The flag is the last field; the count is the one before it.
                overwrite = True
                line, field = line.rsplit(" ", 1)
            else:
                overwrite = False
            count = int(field)
            word = line
            if word in self and not overwrite:
                raise RuntimeError(
                    "Duplicate word found when loading Dictionary: '{}'. "
                    "Duplicate words can overwrite earlier ones by adding the "
                    "#fairseq:overwrite flag at the end of the corresponding row "
                    "in the dictionary file. If using the Camembert model, please "
                    "download an updated copy of the model file.".format(
                        word))
            self.add_symbol(word, n=count, overwrite=overwrite)
        except ValueError:
            raise ValueError(
                "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
            )
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
    """Parse every line of (a chunk of) a file as an alignment and pass it on.

    Reading starts at ``offset`` and stops at end of file, or once the
    reported file position exceeds ``end`` (only when ``end > 0``).

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        alignment_parser: callable turning one line into the value given to
            ``consumer``.
        consumer: callable receiving each parsed alignment.
        offset: position to seek to before reading.
        end: position after which reading stops; ignored unless positive.

    Returns:
        ``{"nseq": <number of lines processed>}``.
    """
    processed = 0
    with open(PathManager.get_local_path(filename), "r") as f:
        f.seek(offset)
        line = safe_readline(f)
        while line:
            past_end = end > 0 and f.tell() > end
            if past_end:
                break
            parsed = alignment_parser(line)
            processed += 1
            consumer(parsed)
            line = f.readline()
    return {"nseq": processed}
def binarize_da(filename, consumer, da_mapping, offset=0, end=-1) -> Dict[str, int]:
    """Map every line of (a chunk of) a file to a one-element id tensor.

    The stripped line is looked up in ``da_mapping``; lines without an entry
    get the id ``len(da_mapping) + 1``.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        consumer: callable receiving each id tensor.
        da_mapping: mapping from stripped line text to an integer id.
        offset: position to seek to before reading.
        end: position after which reading stops; ignored unless positive.

    Returns:
        ``{"nseq": <number of lines processed>}``.
    """
    processed = 0
    fallback_id = len(da_mapping) + 1
    with open(PathManager.get_local_path(filename), "r") as f:
        f.seek(offset)
        line = safe_readline(f)
        while line:
            past_end = end > 0 and f.tell() > end
            if past_end:
                break
            label_id = da_mapping.get(line.strip(), fallback_id)
            processed += 1
            consumer(torch.IntTensor([label_id]))
            line = f.readline()
    return {"nseq": processed}
def add_ordered_utt_list(self, *args):
    """Set ``self.ordered_utt_list`` from a list or from text files.

    Args:
        *args: either a single ``list`` of utterance ids, which is stored
            as-is, or any number of text file paths. For files, the first
            whitespace-separated field of every line is taken as the
            utterance id, in file order.

    Raises:
        AssertionError: when ids were read from files and any non-empty one
            of ``self.char_results``, ``self.results`` or
            ``self.aligned_results`` has a key set different from the ids.
    """
    given_as_list = len(args) == 1 and isinstance(args[0], list)
    if given_as_list:  # already a list of utterance ids
        self.ordered_utt_list = args[0]
        return

    self.ordered_utt_list = []
    for text_file in args:
        with open(PathManager.get_local_path(text_file), "r", encoding="utf-8") as f:
            utt_ids = [row.strip().split()[0] for row in f]
        self.ordered_utt_list.extend(utt_ids)

    # Every non-empty result table must cover exactly the listed utterances.
    for table_name in ("char_results", "results", "aligned_results"):
        table = getattr(self, table_name)
        if len(table) > 0:
            assert set(self.ordered_utt_list) == set(table.keys())
def load(cls, f, enable_bos=False, f_non_lang_syms=None):
    """Loads the dictionary from a text file with the format:

    ```
    <symbol0> <count0>
    <symbol1> <count1>
    ...
    ```

    Optionally add bos symbol to the dictionary

    Identifies the space symbol if it exists, by obtaining its index
    (space_index=-1 if no space symbol)

    Loads non_lang_syms from another text file, if it exists, with one
    symbol per line

    Args:
        f: dictionary source, passed to ``add_from_file``.
        enable_bos: forwarded to the ``cls`` constructor.
        f_non_lang_syms: optional path (``str``) of the non-language-symbol
            file; opened as UTF-8 through ``PathManager.get_local_path``.

    Returns:
        The populated dictionary instance.

    Raises:
        FileNotFoundError: if ``f_non_lang_syms`` cannot be opened.
        Exception: if ``f_non_lang_syms`` is not valid UTF-8.
        AssertionError: if ``f_non_lang_syms`` is not a ``str``, or one of
            its symbols maps to the unknown index.
    """
    d = cls(enable_bos=enable_bos)
    d.add_from_file(f)

    d.space_index = d.indices.get(d.space_word, -1)

    if f_non_lang_syms is not None:
        assert isinstance(f_non_lang_syms, str)
        try:
            with open(PathManager.get_local_path(f_non_lang_syms), "r", encoding="utf-8") as fd:
                non_lang_syms = [x.rstrip() for x in fd.readlines()]
        except FileNotFoundError as fnfe:
            raise fnfe
        except UnicodeError:
            # Report the offending path (previously the file object itself
            # was formatted into the message).
            raise Exception("Incorrect encoding detected in {}, please "
                            "rebuild the dataset".format(f_non_lang_syms))

        for sym in non_lang_syms:
            assert (d.index(sym) != d.unk()), "{} in {} is not in the dictionary".format(
                sym, f_non_lang_syms)
        d.non_lang_syms = non_lang_syms

    return d
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: checkpoint location; resolved through
            ``PathManager.get_local_path``.
        arg_overrides: optional mapping of names to values. Each entry is set
            as an attribute on ``state["args"]`` and passed to
            ``overwrite_args_by_name`` for ``state["cfg"]``, for whichever of
            the two is present and not None.

    Returns:
        The loaded state after being passed through ``_upgrade_state_dict``.
    """
    # HACK(review): machine-specific redirect of one hard-coded checkpoint
    # path to a local copy. Kept so the existing setup keeps working, but it
    # should be replaced by configuration and removed from library code.
    if path == '/home/ubuntu/project/model/wav2vec_small.pt':
        path = '/home/ras306/Classwork/Project/Wav2Vec/model/wav2vec_small.pt'

    with open(PathManager.get_local_path(path), "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
        overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
def parse_wer_output_filter(self, wer_output_filter):
    """Read substitution rules from a filter file into ``self.word_filters``.

    Blank lines and lines starting with ``#!`` are skipped. Lines of the form
    ``s/<pattern>/<replacement>/g`` or ``s:<pattern>:<replacement>:g`` are
    appended to ``self.word_filters`` as ``[pattern, replacement]``. Any
    other line is logged as unsupported and ignored.

    Args:
        wer_output_filter: path of the filter file; nothing happens when it
            is falsy.

    Raises:
        AssertionError: if a line starts with ``s/`` or ``s:`` but does not
            match the corresponding full pattern.
    """
    if not wer_output_filter:
        return

    # (line prefix, full regex) for each supported delimiter.
    rule_forms = (
        ('s/', r's/(.+)/(.*)/g'),
        ('s:', r's:(.+):(.*):g'),
    )

    with open(PathManager.get_local_path(wer_output_filter), 'r', encoding='utf-8') as f:
        for raw in f:
            rule = raw.strip()
            if rule == '' or rule.startswith('#!'):
                continue
            for prefix, pattern in rule_forms:
                if rule.startswith(prefix):
                    m = re.match(pattern, rule)
                    assert m is not None
                    self.word_filters.append([m.group(1), m.group(2)])
                    break
            else:
                logger.warning(
                    'Unsupported pattern: "{}". Ignoring it'.format(rule))
def parse_wer_output_filter(self, wer_output_filter):
    """Read substitution rules from a filter file into ``self.word_filters``.

    Blank lines and lines starting with ``#!`` are skipped. Lines of the form
    ``s/<pattern>/<replacement>/g`` or ``s:<pattern>:<replacement>:g`` are
    appended to ``self.word_filters`` as ``[pattern, replacement]``. Any
    other line is logged as unsupported and ignored.

    Args:
        wer_output_filter: path of the filter file; nothing happens when it
            is falsy.

    Raises:
        AssertionError: if a line starts with ``s/`` or ``s:`` but does not
            match the corresponding full pattern.
    """
    if not wer_output_filter:
        return

    # (line prefix, full regex) for each supported delimiter.
    rule_forms = (
        ("s/", r"s/(.+)/(.*)/g"),
        ("s:", r"s:(.+):(.*):g"),
    )

    with open(PathManager.get_local_path(wer_output_filter), "r", encoding="utf-8") as f:
        for raw in f:
            rule = raw.strip()
            if rule == "" or rule.startswith("#!"):
                continue
            for prefix, pattern in rule_forms:
                if rule.startswith(prefix):
                    m = re.match(pattern, rule)
                    assert m is not None
                    self.word_filters.append([m.group(1), m.group(2)])
                    break
            else:
                logger.warning(
                    "Unsupported pattern: '{}'. Ignoring it".format(rule))
def _consume_file(
    filename: str,
    binarizer: Binarizer,
    consumer: tp.Callable[[torch.IntTensor], None],
    offset_start: int,
    offset_end: int,
) -> tp.Dict[str, int]:
    """Binarize the lines of one file chunk and feed them to ``consumer``.

    Args:
        filename: input file; resolved through ``PathManager.get_local_path``.
        binarizer: object whose ``binarize_line(line, summary)`` produces the
            value given to ``consumer``.
        consumer: callable receiving each binarized line.
        offset_start: chunk start, forwarded to ``Chunker``.
        offset_end: chunk end, forwarded to ``Chunker``.

    Returns:
        A dict built from the accumulated ``BinarizeSummary``: ``nseq``,
        ``nunk``, ``ntok`` and ``replaced``.
    """
    stats = BinarizeSummary()

    local_file = PathManager.get_local_path(filename)
    with Chunker(local_file, offset_start, offset_end) as line_iterator:
        for text in line_iterator:
            ids = binarizer.binarize_line(text, stats)
            consumer(ids)

    return dict(
        nseq=stats.num_seq,
        nunk=stats.num_replaced,
        ntok=stats.num_tok,
        replaced=stats.replaced,
    )
def upgrade_state_dict_with_infoxlm_weights(
    state_dict: Dict[str, Any],
    pretrained_infoxlm_checkpoint: str,
    num_layers: int,
    shared_cross_attn: bool = False,
) -> Dict[str, Any]:
    """
    Copy weights from a pretrained InfoXLM checkpoint into ``state_dict``.

    Only checkpoint entries whose key starts with ``decoder.`` and does not
    contain ``lm_head`` are used. Layer entries whose layer index (fourth
    dot-separated key component) is ``num_layers`` or higher are skipped.
    Fused ``in_proj_weight`` / ``in_proj_bias`` tensors are split into three
    chunks along dim 0 and stored as ``q_proj`` / ``k_proj`` / ``v_proj``.

    Args:
        state_dict: state dict to update in place.
        pretrained_infoxlm_checkpoint: checkpoint to load the weights from.
        num_layers: number of layers to take from the checkpoint.
        shared_cross_attn: additionally store the split q/k/v tensors under
            the same keys with ``self_attn`` replaced by ``encoder_attn``.

    Returns:
        The updated ``state_dict`` (the same object that was passed in).

    Raises:
        IOError: if ``pretrained_infoxlm_checkpoint`` does not exist.
    """
    if not os.path.exists(pretrained_infoxlm_checkpoint):
        raise IOError(
            "Model file not found: {}".format(pretrained_infoxlm_checkpoint))

    with open(PathManager.get_local_path(pretrained_infoxlm_checkpoint), "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
    infoxlm_state_dict = state["model"]

    encoder_prefix = 'decoder.sentence_encoder.'
    last_layer = num_layers - 1

    for key in infoxlm_state_dict.keys():
        if 'layers' in key and int(key.split('.')[3]) > last_layer:
            continue
        if not key.startswith('decoder.') or 'lm_head' in key:
            continue

        tensor = infoxlm_state_dict[key]
        short_key = key.replace(encoder_prefix, '')

        # Fused attention projections: weight is checked before bias.
        if 'in_proj_weight' in key:
            fused, kind = 'in_proj_weight', 'weight'
        elif 'in_proj_bias' in key:
            fused, kind = 'in_proj_bias', 'bias'
        else:
            fused = None

        if fused is not None:
            q, k, v = tensor.chunk(3, dim=0)
            targets = [
                short_key.replace(fused, '{}_proj.{}'.format(name, kind))
                for name in ('q', 'k', 'v')
            ]
            for target, part in zip(targets, (q, k, v)):
                state_dict[target] = part
            if shared_cross_attn:
                for target, part in zip(targets, (q, k, v)):
                    state_dict[target.replace('self_attn', 'encoder_attn')] = part
        elif 'emb_layer_norm' in key:
            renamed = key.replace(
                'decoder.sentence_encoder.emb_layer_norm', 'layernorm_embedding')
            state_dict[renamed] = tensor
        elif 'embed_positions' in key:
            # Keep only as many rows as the existing entry has.
            state_dict[short_key] = tensor[:state_dict[short_key].size(0)]
        elif 'embed_tokens' in key:
            # Overwrite the leading rows of the existing entry in place.
            state_dict[short_key][:tensor.size(0)] = tensor
        else:
            state_dict[short_key] = tensor

    return state_dict
def binarize(
    filename,
    dict,
    consumer,
    tokenize=tokenize_line,
    append_eos=True,
    reverse_order=False,
    offset=0,
    end=-1,
    already_numberized=False,
    avoid_tokenize=False,
) -> Dict[str, int]:
    """Convert the lines of (a chunk of) a text file to id tensors.

    How a line is converted depends on ``already_numberized`` and on the type
    of ``dict``, checked in this order:

    1. ``already_numberized``: whitespace-separated integer ids.
    2. ``BertTokenizer`` (but not ``ElectraTokenizer``): line wrapped in
       ``[CLS]`` / ``[SEP]``.
    3. ``BartTokenizer``: line wrapped in ``<s>`` / ``</s>``.
    4. ``ElectraTokenizer``: same handling as case 2.
    5. ``dict is None``: whitespace-separated integer ids, taken as-is.
    6. anything else: ``dict.encode_line``.

    For the tokenizer cases the wrapped line is tokenized with
    ``dict.tokenize`` when ``avoid_tokenize is False`` and split on
    whitespace otherwise.

    Args:
        filename: file to read; resolved through ``PathManager.get_local_path``.
        dict: dictionary or pretrained tokenizer (or ``None``).
        consumer: callable receiving each id tensor.
        tokenize: line tokenizer forwarded to ``dict.encode_line``.
        append_eos: append the end-of-sentence id (cases 1 and 6 only).
        reverse_order: reverse the ids (cases 1 and 6 only).
        offset: position to seek to before reading.
        end: position after which reading stops; ignored unless positive.
        already_numberized: see case 1.
        avoid_tokenize: split on whitespace instead of calling
            ``dict.tokenize`` in the tokenizer cases.

    Returns:
        A dict with ``nseq`` (lines processed), ``nunk`` (total replaced
        words), ``ntok`` (total ids) and ``replaced`` (Counter of words that
        were mapped to the unknown token).
    """
    replaced = Counter()

    def replaced_consumer(word, idx):
        if idx == dict.unk_index and word != dict.unk_word:
            replaced.update([word])

    def replaced_consumer_from_pretrained(word, idx):
        unk_id = dict.convert_tokens_to_ids(dict.unk_token)
        if idx == unk_id and word != dict.unk_token:
            replaced.update([word])

    def fill_tensor(values):
        tensor = torch.IntTensor(len(values))
        for pos, value in enumerate(values):
            tensor[pos] = value
        return tensor

    def ids_from_tokens(tokens, require_same_length=False):
        # Convert tokens with the pretrained tokenizer and record unknowns.
        token_ids = dict.convert_tokens_to_ids(tokens)
        if require_same_length:
            assert len(tokens) == len(token_ids)
        tensor = torch.IntTensor(len(token_ids))
        for pos, token_id in enumerate(token_ids):
            tensor[pos] = token_id
            replaced_consumer_from_pretrained(tokens[pos], token_id)
        return tensor

    def bert_style_ids(raw):
        text = '{} {} {}'.format('[CLS]', raw.strip(), '[SEP]')
        if avoid_tokenize is False:
            tokens = dict.tokenize(text)
        else:
            tokens = text.strip().split()
        return ids_from_tokens(tokens)

    def bart_style_ids(raw):
        text = raw.strip()
        if avoid_tokenize is False:
            # extra space at the end will cause weird outputs.
            tokens = dict.tokenize('{} {}{}'.format('<s>', text, '</s>'))
        else:
            tokens = '{} {} {}'.format('<s>', text, '</s>').strip().split()
        return ids_from_tokens(tokens, require_same_length=True)

    def numberized_ids(raw):
        id_list = [int(id_string) for id_string in raw.strip().split()]
        if reverse_order:
            id_list.reverse()
        if append_eos:
            id_list.append(dict.eos())
        return torch.IntTensor(id_list)

    nseq = ntok = 0
    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        f.seek(offset)
        # next(f) breaks f.tell(), hence readline() must be used
        line = safe_readline(f)
        while line:
            # f.tell() does not always give the byte position in the file
            # sometimes it skips to a very large number
            # it is unlikely that through a normal read we go from
            # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely
            # that the procedure breaks by the undeterministic behavior of
            # f.tell()
            if end > 0:
                position = f.tell()
                if end < position < end + 2**32:
                    break

            if already_numberized:
                ids = numberized_ids(line)
            elif isinstance(dict, BertTokenizer) and not isinstance(
                    dict, ElectraTokenizer):
                ids = bert_style_ids(line)
            elif isinstance(dict, BartTokenizer):
                ids = bart_style_ids(line)
            elif isinstance(dict, ElectraTokenizer):
                ids = bert_style_ids(line)
            elif dict is None:
                ids = fill_tensor([int(item) for item in line.strip().split()])
            else:
                ids = dict.encode_line(
                    line=line,
                    line_tokenizer=tokenize,
                    add_if_not_exist=False,
                    consumer=replaced_consumer,
                    append_eos=append_eos,
                    reverse_order=reverse_order,
                )

            nseq += 1
            ntok += len(ids)
            consumer(ids)
            line = f.readline()

    return {
        "nseq": nseq,
        "nunk": sum(replaced.values()),
        "ntok": ntok,
        "replaced": replaced,
    }