Example #1
File: data.py Project: thammegowda/mtdata
 def add_parts_sequential(self,
                          dir_path,
                          entries,
                          drop_noise=False,
                          compress=False,
                          desc=None,
                          fail_on_error=False):
     with pbar_man.counter(color='blue',
                           leave=False,
                           total=len(entries),
                           unit='it',
                           desc=desc,
                           autorefresh=True,
                           position=3) as pbar:
         for ent in entries:
             try:
                 n_good, n_bad = self.add_part(dir_path=dir_path,
                                               entry=ent,
                                               drop_noise=drop_noise,
                                               compress=compress)
                 if max(n_good, n_bad) >= 0:  # -1 indicates the record was skipped because it is already valid
                     log.info(f"{ent.did.name}: found {n_good} segments and {n_bad} errors")
                 pbar.update(force=True)
             except Exception as e:
                 log.error(f"Unable to add {ent.did}: {e}")
                 if fail_on_error:
                     raise e
                 msg = str(e).replace('\n', '\t')
                 with self.errors_file.open('a') as err_f:  # append and close the handle
                     err_f.write(f"{ent.did}\t{msg}\n")
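These examples revolve around pbar_man.counter; pbar_man looks like an enlighten progress-bar manager (enlighten.get_manager()), though that binding is outside the snippet. A minimal standalone sketch of the same counter pattern under that assumption (the manager name and the dummy workload are placeholders, not mtdata code):

import enlighten

pbar_man = enlighten.get_manager()  # assumption: this is what mtdata binds pbar_man to
items = list(range(100))            # placeholder workload

with pbar_man.counter(color='blue', total=len(items), unit='it',
                      desc='Processing', leave=False) as pbar:
    for item in items:
        # ... do the real per-item work here ...
        pbar.update()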
Example #2
File: data.py Project: thammegowda/mtdata
    def add_parts(self,
                  dir_path,
                  entries,
                  drop_noise=False,
                  compress=False,
                  desc=None,
                  fail_on_error=False):
        assert isinstance(entries, list)
        if self.n_jobs == 1:
            return self.add_parts_sequential(dir_path=dir_path,
                                             entries=entries,
                                             drop_noise=drop_noise,
                                             compress=compress,
                                             desc=desc,
                                             fail_on_error=fail_on_error)

        tasks = [
            dict(dir_path=dir_path,
                 entry=ent,
                 drop_noise=drop_noise,
                 compress=compress,
                 fail_on_error=fail_on_error) for ent in entries
        ]
        with Pool(self.n_jobs) as pool:  # context manager ensures worker processes are cleaned up
            with pbar_man.counter(color='blue',
                                  leave=False,
                                  total=len(entries),
                                  unit='it',
                                  desc=desc,
                                  autorefresh=True,
                                  position=3) as pbar:
                for _ in pool.imap_unordered(self.add_part_thread, tasks):
                    pbar.update(force=True)
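When n_jobs > 1, the work is fanned out with multiprocessing and the bar ticks as results arrive in completion order via imap_unordered. A stripped-down sketch of that pattern (process_task is a hypothetical worker, not mtdata's add_part_thread):

from multiprocessing import Pool

import enlighten

def process_task(task):
    # placeholder for the real per-entry work; must be picklable, so module-level
    return task

def run_all(tasks, n_jobs=4):
    pbar_man = enlighten.get_manager()
    with Pool(n_jobs) as pool, \
            pbar_man.counter(total=len(tasks), unit='it', desc='Adding parts') as pbar:
        for _ in pool.imap_unordered(process_task, tasks):
            pbar.update()

imap_unordered lets the counter advance as soon as any worker finishes, rather than in submission order.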
Example #3
File: data.py Project: thammegowda/mtdata
 def cat_bitexts(self, in_paths: List[Tuple[Path, Path]],
                 out_paths: Tuple[Path, Path]):
     of1, of2 = out_paths
     of1.parent.mkdir(exist_ok=True)
     of2.parent.mkdir(exist_ok=True)
     with pbar_man.counter(color='green', total=len(in_paths), unit='it', desc="Merging") as pbar, \
             IO.writer(of1) as w1, IO.writer(of2) as w2:
         for if1, if2 in in_paths:
             assert if1.exists()
             assert if2.exists()
             for seg1, seg2 in self.read_parallel(if1, if2):
                 w1.write(seg1 + '\n')
                 w2.write(seg2 + '\n')
             pbar.update()
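This merge (and the larger one in Example #5) leans on self.read_parallel, whose body is not shown here. A hypothetical stand-in, assuming plain-text inputs with one segment per line and no compression handling, could look like this:

from itertools import zip_longest
from pathlib import Path

def read_parallel(path1: Path, path2: Path):
    # hypothetical: yield aligned line pairs, failing loudly on a length mismatch
    with path1.open() as f1, path2.open() as f2:
        for i, (line1, line2) in enumerate(zip_longest(f1, f2), start=1):
            if line1 is None or line2 is None:
                raise ValueError(f'{path1} and {path2} differ in length near line {i}')
            yield line1.rstrip('\n'), line2.rstrip('\n')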
Example #4
File: cache.py Project: thammegowda/mtdata
    def download(self, url: str, save_at: Path, timeout=(5, 10)):
        valid_flag = self.get_flag_file(save_at)
        lock_file = valid_flag.with_suffix("._lock")
        if valid_flag.exists() and save_at.exists():
            return save_at
        save_at.parent.mkdir(parents=True, exist_ok=True)

        log.info(f"Acquiring lock on {lock_file}")
        with portalocker.Lock(lock_file, 'w', timeout=FILE_LOCK_TIMEOUT) as fh:
            # check whether another parallel process already downloaded it
            if valid_flag.exists() and save_at.exists():
                return save_at
            log.info(f"GET {url} → {save_at}")
            resp = requests.get(url=url,
                                allow_redirects=True,
                                headers=headers,
                                stream=True,
                                timeout=timeout)
            assert resp.status_code == 200, resp.status_code
            buf_size = 2**14
            tot_bytes = int(resp.headers.get('Content-Length', '0'))
            n_buffers = math.ceil(tot_bytes / buf_size) or None
            desc = url
            if len(desc) > 60:
                desc = desc[:30] + '...' + desc[-28:]
            with pbar_man.counter(color='green',
                                  total=tot_bytes // 2**10,
                                  unit='KiB',
                                  leave=False,
                                  position=2,
                                  desc=desc) as pbar, \
                    open(save_at, 'wb', buffering=2**24) as out:
                for chunk in resp.iter_content(chunk_size=buf_size):
                    out.write(chunk)
                    pbar.update(incr=buf_size // 2**10)
            valid_flag.touch()
            lock_file.unlink()
            return save_at
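The caching discipline here is worth isolating: a flag file marks a finished download, a lock file serializes concurrent callers, and the flag is re-checked after the lock is acquired so only one process does the work. A sketch of that idiom with a placeholder do_download callable and an arbitrary timeout (not mtdata's FILE_LOCK_TIMEOUT):

from pathlib import Path

import portalocker

LOCK_TIMEOUT = 60 * 60  # seconds; arbitrary for this sketch

def download_once(target: Path, do_download):
    flag = target.with_name(target.name + '._valid')
    lock = target.with_name(target.name + '._lock')
    if flag.exists() and target.exists():       # fast path: already downloaded
        return target
    target.parent.mkdir(parents=True, exist_ok=True)
    with portalocker.Lock(str(lock), 'w', timeout=LOCK_TIMEOUT):
        if flag.exists() and target.exists():   # re-check after acquiring the lock
            return target
        do_download(target)                     # placeholder for the actual HTTP transfer
        flag.touch()                            # mark completion for future callers
    return target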
Example #5
File: data.py Project: thammegowda/mtdata
    def add_train_entries(self,
                          entries,
                          merge_train=False,
                          compress=False,
                          drop_hashes=None):
        self.add_parts(self.train_parts_dir,
                       entries,
                       drop_noise=self.drop_train_noise,
                       compress=compress,
                       desc='Training sets',
                       fail_on_error=self.fail_on_error)
        if not merge_train:
            return
        lang1, lang2 = self.langs
        # paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2)
        paired_files = {}
        for ent in entries:
            e1, e2 = self.get_paths(self.train_parts_dir, ent)
            _, swapped = BCP47Tag.check_compat_swap(self.langs,
                                                    ent.did.langs,
                                                    fail_on_incompat=True)
            if swapped:
                e1, e2 = e2, e1
            paired_files[str(ent.did)] = e1, e2

        log.info(f"Going to merge {len(paired_files)} files as one train file")
        compress_ext = f'.{DEF_COMPRESS}' if compress else ''
        l1_ext = f'{lang1}{compress_ext}'
        l2_ext = f'{lang2}{compress_ext}'
        of1 = self.dir / f'train.{l1_ext}'
        of2 = self.dir / f'train.{l2_ext}'
        of3 = self.dir / f'train.meta.{DEF_COMPRESS}'

        counts = dict(total=coll.defaultdict(int),
                      dupes_skips=coll.defaultdict(int),
                      test_overlap_skips=coll.defaultdict(int),
                      selected=coll.defaultdict(int))
        train_hashes = set()

        with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
            with pbar_man.counter(color='green',
                                  total=len(paired_files),
                                  unit='it',
                                  desc="Merging",
                                  autorefresh=False) as pbar:
                for name, (if1, if2) in paired_files.items():
                    for seg1, seg2 in self.read_parallel(if1, if2):
                        counts['total'][name] += 1
                        if self.drop_dupes or self.drop_tests:
                            hash_val = hash((seg1, seg2))
                            if drop_hashes and (hash_val in drop_hashes
                                                or hash(seg1) in drop_hashes
                                                or hash(seg2) in drop_hashes):
                                counts['test_overlap_skips'][name] += 1
                                continue
                            if self.drop_dupes:
                                if hash_val in train_hashes:
                                    counts['dupes_skips'][name] += 1
                                    continue
                                train_hashes.add(hash_val)
                        w1.write(seg1 + '\n')
                        w2.write(seg2 + '\n')
                        w3.write(name + '\n')
                        counts['selected'][name] += 1
                    pbar.update()

        stats = dict(selected=sum(counts['selected'].values()),
                     total=sum(counts['total'].values()),
                     counts=counts)

        stats_msg = json.dumps(stats, indent=2)
        log.info('Train stats:\n' + stats_msg)
        IO.write_lines(self.dir / 'train.stats.json', stats_msg)
        return counts
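The filtering inside the merge loop is self-contained enough to read on its own: a pair is dropped if its hash, or the hash of either side, appears in drop_hashes (treated as test-set overlap), and otherwise dropped as a duplicate if the exact pair was already seen. A standalone sketch of that filter (the function name and defaults are illustrative):

def filter_pairs(pairs, drop_hashes=None, drop_dupes=True):
    seen = set()
    for seg1, seg2 in pairs:
        pair_hash = hash((seg1, seg2))
        if drop_hashes and (pair_hash in drop_hashes
                            or hash(seg1) in drop_hashes
                            or hash(seg2) in drop_hashes):
            continue  # overlaps held-out (test/dev) material
        if drop_dupes:
            if pair_hash in seen:
                continue  # exact duplicate of an earlier pair
            seen.add(pair_hash)
        yield seg1, seg2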
Example #6
    def read_segs(self):
        readers = []
        if self.ext == 'opus_xces':
            preprocessing = 'xml'
            if "/raw/" in self.ent.in_paths[0]:
                preprocessing = 'raw'
            align, lang1_dir, lang2_dir = self.paths
            from mtdata.opus_xces import OpusXcesParser
            reader = OpusXcesParser.read(align,
                                         lang1_dir,
                                         lang2_dir,
                                         preprocessing=preprocessing)
            readers.append(reader)
        else:
            for p in self.paths:
                if 'tsv' in self.ext:
                    cols = (0, 1)  # extract the first two columns
                    if self.ent and self.ent.cols:
                        cols = self.ent.cols
                    readers.append(self.read_tsv(p, cols=cols))
                elif 'csvwithheader' in self.ext:
                    readers.append(self.read_tsv(p, delim=',',
                                                 skipheader=True))
                elif 'raw' in self.ext or 'txt' in self.ext:
                    readers.append(self.read_plain(p))
                elif 'tmx' in self.ext:
                    from mtdata.tmx import read_tmx
                    readers.append(read_tmx(path=p, langs=self.ent.did.langs))
                elif 'sgm' in self.ext:
                    from mtdata.sgm import read_sgm
                    readers.append(read_sgm(p))
                elif 'wmt21xml' in self.ext:
                    from mtdata.sgm import read_wmt21_xml
                    readers.append(read_wmt21_xml(p))
                else:
                    raise Exception(f'Unsupported format {self.ext}: {p}')

        if len(readers) == 1:
            data = readers[0]
        elif self.ext == 'tmx' or self.ext == 'tsv':
            data = (rec for reader in readers
                    for rec in reader)  # flatten all readers
        elif len(readers) == 2:

            def _zip_n_check():
                for seg1, seg2 in zip_longest(*readers):
                    if seg1 is None or seg2 is None:
                        raise Exception(
                            f'{self.paths} have unequal number of segments')
                    yield seg1, seg2

            data = _zip_n_check()
        else:
            raise Exception("This is an error")
        with pbar_man.counter(color='green',
                              unit='seg',
                              leave=False,
                              desc=f"Reading {self.ent.did}") as pbar:
            for rec in data:
                yield rec
                pbar.update()
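Unlike the earlier examples, this reader does not know the segment count up front, so the counter is created without total and simply counts upward. The same wrapping can be written as a small reusable generator (again assuming pbar_man comes from enlighten.get_manager()):

import enlighten

def with_progress(records, desc='Reading'):
    # wrap any iterator with a running count; no total means a counter, not a bar
    pbar_man = enlighten.get_manager()
    with pbar_man.counter(unit='seg', desc=desc, leave=False) as pbar:
        for rec in records:
            yield rec
            pbar.update()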