示例#1
0
    def preprocess(self):
        ''' Do any data preprocessing if needed.

        Returns immediately when every path in ``self.data_paths`` and
        ``self.vocab_paths`` already exists. Otherwise it ensures the preprocess
        directory exists, calls ``download_and_extract`` and ``preprocess_raw``,
        loads the vocab in preprocessing mode, and runs ``preprocess_bpe`` on the
        file for the current split. For parse-based annotations it also runs
        ``preprocess_parse`` on every split and writes the tokens appended to
        ``self.id2token`` during that step to ``self.constituent_vocab_path``
        (one per line), unless that file already exists.
        '''
        if (all(os.path.exists(p) for p in self.data_paths)
                and all(os.path.exists(p) for p in self.vocab_paths)):
            return

        if not os.path.exists(self.preprocess_directory):
            os.makedirs(self.preprocess_directory)

        self.download_and_extract()
        self.preprocess_raw()

        # Make sure we have loaded the vocab
        self.load_vocab(preprocessing=True)

        split_filename = type(self).SPLITS[self.split]
        self.preprocess_bpe(split_filename)

        if self.annotation in (TextAnnotation.PARSE_SPANS,
                               TextAnnotation.CONSTITUENCY_PARSE):
            # Everything appended to id2token after this point (presumably by
            # preprocess_parse) is treated as an annotation token.
            base_annotation_id = len(self.id2token)
            for filename in type(self).SPLITS.values():
                self.preprocess_parse(filename)

            if not os.path.exists(self.constituent_vocab_path):
                # Write under a temporary name and rename when done, like the
                # other preprocessing outputs, so an interrupted write can never
                # leave a truncated file that the exists() check above would
                # then accept as complete.
                incomplete_path = f'{self.constituent_vocab_path}.incomplete'
                with Open(incomplete_path, 'wt') as file:
                    file.write('\n'.join(self.id2token[base_annotation_id:]))
                os.rename(incomplete_path, self.constituent_vocab_path)
示例#2
0
    def load_vocab(self, preprocessing=False):
        ''' Load the vocabulary into ``self.token2id`` and ``self.id2token``.

        Reads ``self.base_vocab_path`` (tokens separated by '\\n'; each token's
        id is the next free index), then calls the parent class's
        ``load_vocab``. Unless ``preprocessing`` is true or annotations are
        disabled, it then appends:

        * for constituency parses, the tokens in ``self.annotation_vocab_path``;
        * for parse spans, ``<SPAN1>`` .. ``<SPAN{config.span}>``;
        * the ``MASKED`` token,

        and caches one ``preprocess.ParseSegmenter`` per span size
        ``1..config.span`` in ``self.segmenters``.

        Prints a message and exits with status 1 if a required vocab file is
        missing.
        '''
        # sys.exit rather than the site-module exit(), which is intended only
        # for the interactive shell and is absent under ``python -S``.
        import sys

        def add_token(token):
            ''' Append ``token`` to the vocab with the next free id '''
            self.token2id[token] = len(self.id2token)
            self.id2token.append(token)

        if not os.path.exists(self.base_vocab_path):
            print('Cannot find the vocab file!')
            sys.exit(1)

        with Open(self.base_vocab_path, 'rt') as vocab_file:
            self.token2id = {}
            self.id2token = []
            for token in vocab_file.read().split('\n'):
                add_token(token)

        super(AnnotatedTextDataset, self).load_vocab(preprocessing)
        if preprocessing or self.annotation is TextAnnotation.NONE:
            return

        if self.annotation is TextAnnotation.CONSTITUENCY_PARSE:
            if not os.path.exists(self.annotation_vocab_path):
                print('Cannot find the annotation vocab file!')
                sys.exit(1)

            with Open(self.annotation_vocab_path, 'rt') as vocab_file:
                for token in vocab_file.read().split('\n'):
                    add_token(token)
        elif self.annotation is TextAnnotation.PARSE_SPANS:
            for span in range(1, self.config.span + 1):
                add_token(f'<SPAN{span}>')

        add_token(MASKED)

        # Need to cache off the segmenters as the BPE loading is slow. We do
        # not want that overhead for each subprocess we create in the dataloaders.
        bpe_path = os.path.join(self.preprocess_directory, 'bpe.32000')
        self.segmenters = [
            preprocess.ParseSegmenter(bpe_path, span, self.config.max_span,
                                      self.config.randomize_chunks)
            for span in range(1, self.config.span + 1)
        ]
示例#3
0
    def load_text(self):
        ''' Load the binarized translations and add each pair as a datum.

        ``self.base_data_path`` holds back-to-back records of two sentences.
        Each sentence is an 8-byte native unsigned length (struct format 'Q')
        followed by that many bytes of unsigned-short ('H') values. The first
        sentence of a record is stored under 'input' and the second under
        'target', or the other way round when ``self.swap`` is set. The target
        is always wrapped in ``self.sos_idx`` ... ``self.eos_idx``. Each
        resulting dict is passed to ``self.add_datum``.

        Prints a message and exits with status 1 if the processed files are
        missing or the file ends before a record's second sentence.
        '''
        # sys.exit rather than the site-module exit(), which is intended only
        # for the interactive shell and is absent under ``python -S``.
        import sys

        if not all(os.path.exists(p) for p in self.data_paths):
            print('Cannot find the processed translations!')
            sys.exit(1)

        # ``swap`` only decides which key each on-disk sentence lands in and is
        # constant for the whole file, so resolve it once.
        if self.swap:
            source_key, target_key = 'target', 'input'
        else:
            source_key, target_key = 'input', 'target'

        with Open(self.base_data_path, 'rb') as base_data_file:
            while True:
                source_sentence_len = base_data_file.read(8)
                if not source_sentence_len:
                    break  # clean end of file

                example = {
                    'input': array.array('H'),
                    # prepend the start of sentence token to the target
                    'target': array.array('H', [self.sos_idx]),
                }

                source_sentence_len, = struct.unpack('Q', source_sentence_len)
                # array.fromstring() was removed in Python 3.9; frombytes() is
                # the supported equivalent.
                example[source_key].frombytes(
                    base_data_file.read(source_sentence_len))

                target_sentence_len = base_data_file.read(8)
                if not target_sentence_len:
                    print(
                        'Unexpected end of file while trying to read a de sentence!'
                    )
                    sys.exit(1)

                target_sentence_len, = struct.unpack('Q', target_sentence_len)
                example[target_key].frombytes(
                    base_data_file.read(target_sentence_len))

                # append the end of sentence token to the target
                example['target'].append(self.eos_idx)

                self.add_datum(example)
示例#4
0
    def load_vocab(self, preprocessing=False):
        ''' Load the base vocab into ``self.token2id`` and ``self.id2token``.

        The contents of ``self.base_vocab_path`` are split on '\\n'. Each piece
        is one token whose id is its zero-based position, so a trailing newline
        would yield an empty-string token. If a token repeats, ``token2id``
        keeps the last id. The method then calls the parent class's
        ``load_vocab``.

        Prints a message and exits with status 1 if the vocab file is missing.
        '''
        # sys.exit rather than the site-module exit(), which is intended only
        # for the interactive shell and is absent under ``python -S``.
        import sys

        if not os.path.exists(self.base_vocab_path):
            print('Cannot find the vocab file!')
            sys.exit(1)

        with Open(self.base_vocab_path, 'rt') as vocab_file:
            self.token2id = {}
            self.id2token = []
            for token in vocab_file.read().split('\n'):
                self.token2id[token] = len(self.id2token)
                self.id2token.append(token)

        super(AnnotatedTextDataset, self).load_vocab(preprocessing)
示例#5
0
    def preprocess_bpe(self, filename):
        ''' Binarize the BPE-encoded parallel text of one split.

        Reads ``{filename}.bpe.32000.{source_language}`` and
        ``{filename}.bpe.32000.{target_language}`` from the preprocess
        directory. It writes ``{filename}.bpe.32000.bin`` holding, for each line
        pair, the source sentence followed by the target sentence. Each sentence
        is an 8-byte native unsigned length ('Q') followed by that many bytes of
        unsigned-short ('H') token ids taken from ``self.token2id``. Does
        nothing if the binary file already exists.
        '''
        tokenized_bpe_path = os.path.join(self.preprocess_directory,
                                          f'{filename}.bpe.32000')

        target_path = f'{tokenized_bpe_path}.{self.target_language}'
        source_path = f'{tokenized_bpe_path}.{self.source_language}'
        processed_path = f'{tokenized_bpe_path}.bin'

        if os.path.exists(processed_path):
            return

        def encode_sentence(line):
            ''' Encode a whitespace-split line as a length-prefixed record '''
            sentence = array.array(
                'H', (self.token2id[token] for token in line.split()))
            # array.tostring() was removed in Python 3.9; tobytes() replaces it
            byte_rep = sentence.tobytes()
            # A native 'Q' header then the raw bytes; byte-identical to packing
            # with 'Q{n}s' because 's' needs no alignment padding.
            return struct.pack('Q', len(byte_rep)) + byte_rep

        with ExitStack() as stack:
            source_file = stack.enter_context(Open(source_path, 'rt'))
            target_file = stack.enter_context(Open(target_path, 'rt'))
            out_file = stack.enter_context(tempfile.NamedTemporaryFile())

            # NOTE: zip() stops at the shorter file if the two sides differ
            for source_line, target_line in zip(source_file, target_file):
                source_sentence = encode_sentence(source_line)
                target_sentence = encode_sentence(target_line)

                out_file.write(source_sentence)
                out_file.write(target_sentence)

            # Copy under a temporary name, then rename, so a partially written
            # .bin never appears at processed_path.
            out_file.flush()
            shutil.copy(out_file.name, f'{processed_path}.incomplete')
            os.rename(f'{processed_path}.incomplete', processed_path)
示例#6
0
    def preprocess_raw(self):
        ''' Tokenize/bpe encode the raw text.

        Works inside ``self.preprocess_directory`` in two resumable stages.

        1. If the BPE codes file ``bpe.32000`` is missing, then for each split:
           * merge its raw file pairs into ``{split}.{lang}``, dropping every
             line index for which ``preprocess_raw_line`` returns a falsy value
             in either file of a pair, so the corpora stay parallel;
           * tokenize each merged file into ``{split}.tok.{lang}``.
           BPE codes are then learned from the accumulated word counts.
        2. If ``vocab.bpe.32000`` is missing, apply BPE to every tokenized file
           (writing ``{split}.tok.bpe.32000.{lang}``). Then write the sorted set
           of values returned by ``preprocess.apply_bpe``, one per line.

        Partial outputs carry a ``.incomplete`` suffix until finished. Leftovers
        from an interrupted run are deleted first.
        '''
        def is_xml(filename):
            ''' Determine if a file is XML formatted '''
            return filename.endswith(('.sgm', '.xml'))

        def filter_lines(in_file, basename):
            ''' Return indices of lines that preprocess_raw_line rejects '''
            xml = is_xml(basename)
            return {
                i for i, line in enumerate(in_file)
                if not self.preprocess_raw_line(line, xml=xml)
            }

        def merge(basename, in_file, out_file, filtered=None):
            ''' Preprocess each kept line of in_file and write it to out_file '''
            filtered = filtered or set()
            xml = is_xml(basename)
            for i, line in enumerate(in_file):
                if i in filtered:
                    continue

                processed_line = self.preprocess_raw_line(line, xml=xml)
                out_file.write(processed_line + '\n')

        # First, clean-up any incomplete preprocessing files. glob() already
        # returns paths that include the directory; joining it on again would
        # point at a nonexistent path whenever preprocess_directory is relative.
        for path in glob.glob(
                os.path.join(self.preprocess_directory, '*.incomplete')):
            os.remove(path)

        bpe_code_path = os.path.join(self.preprocess_directory, 'bpe.32000')
        if not os.path.exists(bpe_code_path):
            for split, file_pairs in type(self).RAW_SPLITS.items():
                for pair in file_pairs:
                    # First determine which lines must be skipped in both files,
                    # since the files are a parallel corpus.
                    filtered = set()
                    for filename, _ in zip(pair, type(self).LANGUAGE_PAIR):
                        in_path = os.path.join(self.preprocess_directory,
                                               filename)
                        with Open(in_path, 'rt') as in_file:
                            filtered.update(
                                filter_lines(in_file,
                                             os.path.basename(filename)))

                    for filename, lang in zip(pair, type(self).LANGUAGE_PAIR):
                        split_path = os.path.join(self.preprocess_directory,
                                                  f'{split}.{lang}')
                        if os.path.exists(split_path):
                            continue

                        in_path = os.path.join(self.preprocess_directory,
                                               filename)
                        # Append mode: several raw file pairs feed one split
                        with Open(in_path, 'rt') as in_file, \
                                Open(f'{split_path}.incomplete', 'at') as out_file:
                            merge(os.path.basename(filename), in_file,
                                  out_file, filtered)

            word_counts = Counter()
            for split in type(self).RAW_SPLITS:
                for lang in type(self).LANGUAGE_PAIR:
                    split_path = os.path.join(self.preprocess_directory,
                                              f'{split}.{lang}')
                    try:
                        os.rename(f'{split_path}.incomplete', split_path)
                    except FileNotFoundError:
                        # This can happen if an earlier, interrupted run had
                        # already finalized this split file
                        pass

                    tokenized_path = os.path.join(self.preprocess_directory,
                                                  f'{split}.tok.{lang}')
                    word_counts.update(
                        preprocess.tokenize(split_path, tokenized_path,
                                            self.preprocess_buffer_size))

            print('Learning BPE')
            preprocess.learn_bpe(bpe_code_path, word_counts.items())

        vocab_path = os.path.join(self.preprocess_directory, 'vocab.bpe.32000')
        if not os.path.exists(vocab_path):
            vocab = set()
            for split in type(self).RAW_SPLITS:
                for lang in type(self).LANGUAGE_PAIR:
                    in_path = os.path.join(self.preprocess_directory,
                                           f'{split}.tok.{lang}')
                    bpe_path = os.path.join(self.preprocess_directory,
                                            f'{split}.tok.bpe.32000.{lang}')

                    vocab.update(
                        preprocess.apply_bpe(bpe_code_path, in_path, bpe_path,
                                             self.preprocess_buffer_size))

            incomplete_vocab_path = f'{vocab_path}.incomplete'
            with Open(incomplete_vocab_path, 'wt') as vocab_file:
                vocab_file.write('\n'.join(sorted(vocab)))
            os.rename(incomplete_vocab_path, vocab_path)