def preprocess(self):
    ''' Do any data preprocessing if needed '''
    if (all(os.path.exists(p) for p in self.data_paths) and
            all(os.path.exists(p) for p in self.vocab_paths)):
        return

    if not os.path.exists(self.preprocess_directory):
        os.makedirs(self.preprocess_directory)

    self.download_and_extract()
    self.preprocess_raw()

    # Make sure we have loaded the vocab
    self.load_vocab(preprocessing=True)

    split_filename = type(self).SPLITS[self.split]
    self.preprocess_bpe(split_filename)

    if self.annotation in (TextAnnotation.PARSE_SPANS,
                           TextAnnotation.CONSTITUENCY_PARSE):
        base_annotation_id = len(self.id2token)
        for filename in type(self).SPLITS.values():
            self.preprocess_parse(filename)

        if not os.path.exists(self.constituent_vocab_path):
            with Open(self.constituent_vocab_path, 'wt') as file:
                file.write('\n'.join([
                    self.id2token[annotation_id]
                    for annotation_id in range(
                        base_annotation_id, len(self.id2token))
                ]))

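# A minimal usage sketch, not part of the dataset class: the concrete dataset
# instance and where it is constructed are assumptions -- substitute whatever
# this repository actually instantiates. It only shows the expected call
# order; preprocess() is idempotent (it returns immediately once data_paths
# and vocab_paths all exist), so it is safe to call on every run.
def _example_prepare_dataset(dataset):
    ''' Illustrative only: prepare a dataset instance for training '''
    dataset.preprocess()   # download, tokenize, learn/apply BPE as needed
    dataset.load_vocab()   # token <-> id mappings (+ annotation tokens)
    dataset.load_text()    # read the packed .bin file into examples
    return dataset
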
def load_vocab(self, preprocessing=False):
    ''' Load the vocabulary for the dataset '''
    if not os.path.exists(self.base_vocab_path):
        print('Cannot find the vocab file!')
        exit(1)

    with Open(self.base_vocab_path, 'rt') as vocab_file:
        self.token2id = {}
        self.id2token = []
        for token in vocab_file.read().split('\n'):
            self.token2id[token] = len(self.id2token)
            self.id2token.append(token)

    super(AnnotatedTextDataset, self).load_vocab(preprocessing)

    if preprocessing or self.annotation is TextAnnotation.NONE:
        return

    if self.annotation is TextAnnotation.CONSTITUENCY_PARSE:
        if not os.path.exists(self.annotation_vocab_path):
            print('Cannot find the annotation vocab file!')
            exit(1)

        with Open(self.annotation_vocab_path, 'rt') as vocab_file:
            for token in vocab_file.read().split('\n'):
                self.token2id[token] = len(self.id2token)
                self.id2token.append(token)
    elif self.annotation is TextAnnotation.PARSE_SPANS:
        for i in range(self.config.span):
            token = f'<SPAN{i + 1}>'
            self.token2id[token] = len(self.id2token)
            self.id2token.append(token)

    self.token2id[MASKED] = len(self.id2token)
    self.id2token.append(MASKED)

    # Need to cache off the segmenters as the BPE loading is slow. We do
    # not want that overhead for each subprocess we create in the dataloaders.
    bpe_path = os.path.join(self.preprocess_directory, 'bpe.32000')
    self.segmenters = [
        preprocess.ParseSegmenter(
            bpe_path, span, self.config.max_span,
            self.config.randomize_chunks)
        for span in range(1, self.config.span + 1)
    ]

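# Hedged sketch of the vocabulary file format load_vocab() expects: one BPE
# token per line, with the token id given by the line's position. The file
# name below is illustrative; the real location comes from
# self.base_vocab_path. Uses the builtin open() rather than the repository's
# Open wrapper to stay self-contained.
def _example_read_vocab(vocab_path='vocab.bpe.32000'):
    ''' Illustrative only: rebuild token2id/id2token the way load_vocab does '''
    with open(vocab_path, 'rt') as vocab_file:
        id2token = vocab_file.read().split('\n')

    token2id = {token: idx for idx, token in enumerate(id2token)}
    return token2id, id2token
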
def load_text(self):
    ''' Load the translations '''
    if not all(os.path.exists(p) for p in self.data_paths):
        print('Cannot find the processed translations!')
        exit(1)

    with ExitStack() as stack:
        base_data_file = stack.enter_context(
            Open(self.base_data_path, 'rb'))

        while True:
            if self.swap:
                source_key = 'target'
                target_key = 'input'
            else:
                source_key = 'input'
                target_key = 'target'

            example = {}
            example['input'] = array.array('H')
            example['target'] = array.array('H')

            # prepend the start of sentence token to the target
            example['target'].append(self.sos_idx)

            source_sentence_len = base_data_file.read(8)
            if not source_sentence_len:
                break

            source_sentence_len, = struct.unpack('Q', source_sentence_len)
            example[source_key].frombytes(
                base_data_file.read(source_sentence_len))

            target_sentence_len = base_data_file.read(8)
            if not target_sentence_len:
                print('Unexpected end of file while trying to read a target sentence!')
                exit(1)

            target_sentence_len, = struct.unpack('Q', target_sentence_len)
            example[target_key].frombytes(
                base_data_file.read(target_sentence_len))

            # append the end of sentence token to the target
            example['target'].append(self.eos_idx)

            self.add_datum(example)

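# Hedged sketch of the packed record format that load_text() consumes and
# preprocess_bpe() produces: records alternate source/target, and each record
# is an 8-byte struct 'Q' byte length followed by that many bytes of
# array('H') (uint16) token ids. The path below is illustrative, and the
# sketch relies on the module's existing struct/array imports.
def _example_read_records(bin_path='train.bpe.32000.bin'):
    ''' Illustrative only: yield (source_ids, target_ids) pairs from a .bin file '''
    with open(bin_path, 'rb') as bin_file:
        while True:
            header = bin_file.read(8)
            if not header:
                break  # clean end of file

            source_len, = struct.unpack('Q', header)
            source = array.array('H')
            source.frombytes(bin_file.read(source_len))

            target_len, = struct.unpack('Q', bin_file.read(8))
            target = array.array('H')
            target.frombytes(bin_file.read(target_len))

            yield source, target
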
def load_vocab(self, preprocessing=False):
    ''' Load the vocabulary for the dataset '''
    if not os.path.exists(self.base_vocab_path):
        print('Cannot find the vocab file!')
        exit(1)

    with Open(self.base_vocab_path, 'rt') as vocab_file:
        self.token2id = {}
        self.id2token = []
        for token in vocab_file.read().split('\n'):
            self.token2id[token] = len(self.id2token)
            self.id2token.append(token)

    super(AnnotatedTextDataset, self).load_vocab(preprocessing)

def preprocess_bpe(self, filename):
    ''' Preprocess the BPE data '''
    tokenized_bpe_path = os.path.join(
        self.preprocess_directory, f'{filename}.bpe.32000')

    target_path = f'{tokenized_bpe_path}.{self.target_language}'
    source_path = f'{tokenized_bpe_path}.{self.source_language}'
    processed_path = f'{tokenized_bpe_path}.bin'

    if os.path.exists(processed_path):
        return

    with ExitStack() as stack:
        source_file = stack.enter_context(Open(source_path, 'rt'))
        target_file = stack.enter_context(Open(target_path, 'rt'))

        def encode_sentence(line):
            ''' Helper function that encodes a sentence '''
            sentence = array.array('H')
            sentence.extend(
                self.token2id[token] for token in line.split())

            byte_rep = sentence.tobytes()
            byte_len = len(byte_rep)
            return struct.pack('Q{}s'.format(byte_len), byte_len, byte_rep)

        out_file = stack.enter_context(tempfile.NamedTemporaryFile())
        for source_line, target_line in zip(source_file, target_file):
            source_sentence = encode_sentence(source_line)
            target_sentence = encode_sentence(target_line)

            out_file.write(source_sentence)
            out_file.write(target_sentence)

        out_file.flush()
        shutil.copy(out_file.name, f'{processed_path}.incomplete')
        os.rename(f'{processed_path}.incomplete', processed_path)

def preprocess_raw(self):
    ''' Tokenize/bpe encode the raw text '''
    def is_xml(filename):
        ''' Determine if a file is XML formatted '''
        return filename.endswith('.sgm') or filename.endswith('.xml')

    def filter_lines(in_file, basename):
        ''' Scan the file for any filtered lines '''
        filtered = set()
        xml = is_xml(basename)
        for i, line in enumerate(in_file):
            if not self.preprocess_raw_line(line, xml=xml):
                filtered.add(i)

        return filtered

    def merge(basename, in_file, out_file, filtered=None):
        ''' Tokenize the passed in file and write it to the designated file '''
        filtered = filtered or set()
        xml = is_xml(basename)
        for i, line in enumerate(in_file):
            if i in filtered:
                continue

            processed_line = self.preprocess_raw_line(line, xml=xml)
            out_file.write(processed_line + '\n')

    # First, clean-up any incomplete preprocessing files
    for path in glob.glob(
            os.path.join(self.preprocess_directory, '*.incomplete')):
        os.remove(path)

    bpe_code_path = os.path.join(self.preprocess_directory, 'bpe.32000')
    if not os.path.exists(bpe_code_path):
        for split, file_pairs in type(self).RAW_SPLITS.items():
            for pair in file_pairs:
                # First determine which lines must be skipped in both files,
                # since the files form a parallel corpus.
                filtered = set()
                for filename, lang in zip(pair, type(self).LANGUAGE_PAIR):
                    in_path = os.path.join(self.preprocess_directory, filename)
                    with ExitStack() as stack:
                        in_file = stack.enter_context(Open(in_path, 'rt'))
                        filtered.update(
                            filter_lines(in_file, os.path.basename(filename)))

                for filename, lang in zip(pair, type(self).LANGUAGE_PAIR):
                    basename = os.path.basename(filename)
                    in_path = os.path.join(self.preprocess_directory, filename)
                    split_path = os.path.join(self.preprocess_directory,
                                              f'{split}.{lang}')
                    if os.path.exists(split_path):
                        continue

                    with ExitStack() as stack:
                        out_path = f'{split_path}.incomplete'
                        in_file = stack.enter_context(Open(in_path, 'rt'))
                        out_file = stack.enter_context(Open(out_path, 'at'))
                        merge(basename, in_file, out_file, filtered)

        word_counts = Counter()
        for split in type(self).RAW_SPLITS:
            for lang in type(self).LANGUAGE_PAIR:
                split_path = os.path.join(self.preprocess_directory,
                                          f'{split}.{lang}')
                try:
                    os.rename(f'{split_path}.incomplete', split_path)
                except FileNotFoundError:
                    # This can happen if the preprocessing is interrupted
                    pass

                tokenized_path = os.path.join(self.preprocess_directory,
                                              f'{split}.tok.{lang}')
                word_counts.update(
                    preprocess.tokenize(split_path, tokenized_path,
                                        self.preprocess_buffer_size))

        print('Learning BPE')
        preprocess.learn_bpe(bpe_code_path, word_counts.items())

    vocab_path = os.path.join(self.preprocess_directory, 'vocab.bpe.32000')
    if not os.path.exists(vocab_path):
        vocab = set()
        for split in type(self).RAW_SPLITS:
            for lang in type(self).LANGUAGE_PAIR:
                in_path = os.path.join(self.preprocess_directory,
                                       f'{split}.tok.{lang}')
                bpe_path = os.path.join(self.preprocess_directory,
                                        f'{split}.tok.bpe.32000.{lang}')
                vocab.update(
                    preprocess.apply_bpe(bpe_code_path, in_path, bpe_path,
                                         self.preprocess_buffer_size))

        incomplete_vocab_path = f'{vocab_path}.incomplete'
        with Open(incomplete_vocab_path, 'wt') as vocab_file:
            vocab_file.write('\n'.join(sorted(vocab)))
        os.rename(incomplete_vocab_path, vocab_path)
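
# Hedged sketch of the BPE stage that preprocess_raw() runs, reduced to a
# single split/language for clarity. It reuses only the preprocess-module
# calls already used above (tokenize, learn_bpe, apply_bpe); the file names
# and buffer size are illustrative.
def _example_bpe_stage(raw_path='train.en', buffer_size=4096):
    ''' Illustrative only: tokenize -> learn BPE codes -> apply BPE '''
    word_counts = Counter()
    word_counts.update(
        preprocess.tokenize(raw_path, 'train.tok.en', buffer_size))

    # Learn the merge operations from the token frequencies ...
    preprocess.learn_bpe('bpe.32000', word_counts.items())

    # ... then segment the tokenized text with them, collecting the
    # resulting subword vocabulary
    vocab = set(
        preprocess.apply_bpe('bpe.32000', 'train.tok.en',
                             'train.tok.bpe.32000.en', buffer_size))
    return vocab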