Code Example #1
File: tmx.py Project: kpu/mtdata
def read_tmx(path: Union[Path, str], langs=None):
    """
    reads a TMX file as records
    :param path: path to .tmx file
    :param langs: (lang1, lang2) codes eg (de, en); when it is None the code tries to auto detect
    :return: stream of (text1, text2)
    """
    passes = 0
    fails = 0
    with IO.reader(path) as data:
        recs = parse_tmx(data)
        for lang_seg in recs:
            if langs is None:
                log.warning(
                    "langs not set; this could result in language mismatch")
                if len(lang_seg) == 2:
                    langs = tuple(lang_seg.keys())
                else:
                    raise Exception(
                        f"Language autodetect for TMX only supports 2 languages, but provided with {lang_seg.keys()} in TMX {path}"
                    )
            if langs[0] in lang_seg and langs[1] in lang_seg:
                yield lang_seg[langs[0]], lang_seg[langs[1]]
                passes += 1
            else:
                fails += 1
    if passes == 0:
        if fails == 0:
            raise Exception(f"Empty TMX {path}")
        raise Exception(f"Nothing for {langs[0]}--{langs[1]} in TMX {path}")
    if fails != 0:
        log.warning(
            f"Skipped {fails} entries due to language mismatch in TMX {path}")
    log.info(f"Extracted {passes} pairs from TMX {path}")
Code Example #2
File: __init__.py Project: thammegowda/mtdata
    def load_all(self):
        from mtdata.index import (statmt, paracrawl, tilde, joshua_indian,
                                  unitednations, wikimatrix, other,
                                  neulab_tedtalks, elrc_share, ai4bharat, eu,
                                  linguatools, anuvaad)
        from mtdata.index.opus import opus_index, jw300, opus100
        subsets = [
            ('Statmt.org', statmt.load),
            ('Paracrawl', paracrawl.load),
            ('Tilde', tilde.load),
            ('JoshuaIndianCoprus', joshua_indian.load_all),
            ('UnitedNations', unitednations.load_all),
            ('OPUS', opus_index.load_all),
            # ('OPUS_JW300', jw300.load_all), # JW300 is taken down
            ('OPUS100', opus100.load_all),
            ('WikiMatrix', wikimatrix.load_all),
            ('Other', other.load_all),
            ('Neulab_TEDTalksv1', neulab_tedtalks.load_all),
            ('ELRC-SHARE', elrc_share.load_all),
            ('AI4Bharat', ai4bharat.load_all),
            ('EU', eu.load_all),
            ('LinguaTools', linguatools.load_all),
            ('Anuvaad', anuvaad.load_all),
        ]
        for name, loader in subsets:
            loader(self)

        counts = collections.defaultdict(int)
        for e in self.entries.values():
            counts[e.did.group] += 1
        items = list(sorted(counts.items(), key=lambda x: x[1], reverse=True))
        items += [('Total', len(self))]
        counts = '\n'.join([f'| {n} | {c:,}|' for n, c in items])
        log.info(f"Index status:\n{counts}")
Code Example #3
 def download(self, entry: Entry, save_at: Path):
     save_at.parent.mkdir(parents=True, exist_ok=True)
     log.info(f"GET: {entry.url} --> {save_at}")
     out_file = wget.download(entry.url, out=str(save_at))
     log.info(" Done.")
     assert Path(out_file).resolve() == save_at.resolve()  # saved where we asked it to save
     return save_at
Code Example #4
File: data.py Project: kpu/mtdata
 def add_parts(self, dir_path, entries, drop_noise=False):
     for ent in entries:
         n_good, n_bad = self.add_part(dir_path=dir_path,
                                       entry=ent,
                                       drop_noise=drop_noise)
         log.info(
             f"{ent.name} : found {n_good:} segments and {n_bad:} errors")
Code Example #5
    def load_all(self):
        from mtdata.index import (statmt, paracrawl, tilde, literature,
                                  joshua_indian, unitednations, wikimatrix,
                                  other, neulab_tedtalks, elrc_share,
                                  ai4bharat, eu)
        from mtdata.index.opus import opus_index, jw300, opus100

        counts = {}
        subsets = [('Statmt.org', statmt.load), ('Paracrawl', paracrawl.load),
                   ('Tilde', tilde.load),
                   ('JoshuaIndianCoprus', joshua_indian.load_all),
                   ('UnitedNations', unitednations.load_all),
                   ('OPUS', opus_index.load_all),
                   ('OPUS_JW300', jw300.load_all),
                   ('OPUS100', opus100.load_all),
                   ('WikiMatrix', wikimatrix.load_all),
                   ('Other', other.load_all),
                   ('Neulab_TEDTalksv1', neulab_tedtalks.load_all),
                   ('ELRC-SHARE', elrc_share.load_all),
                   ('AI4Bharat', ai4bharat.load_all), ('EU', eu.load_all)]
        for name, loader in subsets:
            n = len(self)
            loader(self)
            counts[name] = len(self) - n
        items = list(sorted(counts.items(), key=lambda x: x[1], reverse=True))
        items += [('Total', len(self))]
        counts = '  '.join([f'{n}:{c:,}' for n, c in items])
        log.info(f"Index status: {counts}")
        literature.load(self)
Code Example #6
File: data.py Project: kpu/mtdata
    def add_train_entries(self, entries, merge_train=False):
        self.add_parts(self.train_parts_dir,
                       entries,
                       drop_noise=self.drop_train_noise)
        if not merge_train:
            return
        # merge
        l1, l2 = self.langs
        l1_files = list(self.train_parts_dir.glob(f"*.{l1}"))
        assert l1_files and len(l1_files) >= len(entries)

        l2_files = [l1_f.with_suffix(f".{l2}") for l1_f in l1_files]
        assert all(l2_f.exists() for l2_f in l2_files)
        log.info(f"Going to merge {len(l1_files)} files as one train file")
        counts = coll.defaultdict(int)
        of1 = self.dir / f'train.{l1}'
        of2 = self.dir / f'train.{l2}'
        of3 = self.dir / f'train.meta.gz'
        with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
            for if1, if2 in zip(l1_files, l2_files):
                name = if1.name.rstrip(f'.{l1}')
                for seg1, seg2 in self.read_parallel(if1, if2):
                    w1.write(seg1 + '\n')
                    w2.write(seg2 + '\n')
                    w3.write(name + '\n')
                    counts[name] += 1
        total = sum(counts.values())
        counts = {'total': total, 'parts': counts}
        counts_msg = json.dumps(counts, indent=2)
        log.info('Train stats:\n' + counts_msg)
        IO.write_lines(self.dir / 'train.stats.json', counts_msg)
        return counts
Code Example #7
File: data.py Project: thammegowda/mtdata
 def add_parts_sequential(self,
                          dir_path,
                          entries,
                          drop_noise=False,
                          compress=False,
                          desc=None,
                          fail_on_error=False):
     with pbar_man.counter(color='blue',
                           leave=False,
                           total=len(entries),
                           unit='it',
                           desc=desc,
                           autorefresh=True,
                           position=3) as pbar:
         for ent in entries:
             try:
                 n_good, n_bad = self.add_part(dir_path=dir_path,
                                               entry=ent,
                                               drop_noise=drop_noise,
                                               compress=compress)
                 if max(n_good, n_bad
                        ) >= 0:  # -1 for skipped record because it is valid
                     log.info(
                         f"{ent.did.name} : found {n_good:} segments and {n_bad:} errors"
                     )
                 pbar.update(force=True)
             except Exception as e:
                 log.error(f"Unable to add {ent.did}: {e}")
                 if fail_on_error:
                     raise e
                 msg = str(e).replace('\n', '\t')
                 self.errors_file.open('a').write(f"{ent.did}\t{msg}\n")
Code Example #8
File: tmx.py Project: MGheini/mtdata
def parse_tmx(data, n_langs=2, log_every=DEF_PROGRESS):
    context = ET.iterparse(data, events=['end'])
    tus = (el for event, el in context if el.tag == 'tu')
    count, skips = 0, 0
    st = t = time.time()
    for tu in tus:
        langs, segs = [], []
        for tuv in tu.findall('tuv'):
            lang = [v for k, v in tuv.attrib.items() if k.endswith('lang')]
            if lang:
                langs.append(lang[0])
            seg = tuv.findtext('seg')
            if seg:
                segs.append(unescape(seg.strip()))
        if n_langs and len(segs) == len(langs) == n_langs:
            count += 1
            yield list(zip(langs, segs))
        else:
            skips += 1
            log.warning(
                f"Skipped: langs {langs} segs {len(segs)} ; Parsed count {count}"
            )
        if log_every and (time.time() - t) > log_every:
            elapsed = datetime.timedelta(seconds=round(time.time() - st))
            log.info(f"{elapsed} :: Parsed: {count:,} Skipped:{skips:,}")
            t = time.time()
        tu.clear()
    log.info(f"Skipped ={skips}; parsed: {count}")
Code Example #9
def main():
    args = parse_args()
    if args.reindex and cached_index_file.exists():
        bak_file = cached_index_file.with_suffix(".bak")
        log.info(f"Invalidate index: {cached_index_file} -> {bak_file}")
        cached_index_file.rename(bak_file)
    if args.task == 'list':
        list_data(args.langs,
                  args.names,
                  not_names=args.not_names,
                  full=args.full,
                  groups=args.groups,
                  not_groups=args.not_groups,
                  id_only=args.id)
    elif args.task == 'get':
        get_data(**vars(args))
    elif args.task == 'list-recipe':
        list_recipes()
    elif args.task == 'get-recipe':
        get_recipe(**vars(args))
    elif args.task == 'stats':
        show_stats(*args.did)
    elif args.task == 'report':
        generate_report(args.langs, names=args.names, not_names=args.not_names)
    else:
        raise Exception(f'{args.task} not implemented')
Code Example #10
File: tmx.py Project: kpu/mtdata
def parse_tmx(data, log_every=DEF_PROGRESS):
    context = ET.iterparse(data, events=['end'])
    tus = (el for event, el in context if el.tag == 'tu')
    count = 0
    st = t = time.time()
    for tu in tus:
        lang_seg = {}
        for tuv in tu.findall('tuv'):
            lang = [v for k, v in tuv.attrib.items() if k.endswith('lang')]
            seg = tuv.findtext('seg')
            if lang and seg:
                lang = iso3_code(lang[0], fail_error=True)
                seg = unescape(seg.strip()).replace('\n',
                                                    ' ').replace('\t', ' ')
                if lang in lang_seg:
                    log.warning(
                        f"Language {lang} appears twice in same translation unit."
                    )
                lang_seg[lang] = seg
        yield lang_seg
        count += 1
        if log_every and (time.time() - t) > log_every:
            elapsed = datetime.timedelta(seconds=round(time.time() - st))
            log.info(f"{elapsed} :: Parsed: {count:,}")
            t = time.time()
        tu.clear()
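A hedged sketch of driving parse_tmx directly; the import path and file name are assumptions, and a plain binary file handle stands in for the IO.reader used above.

# Hypothetical usage sketch; import path and file name are assumptions.
from mtdata.parser.tmx import parse_tmx  # assumed import path

with open('sample.de-en.tmx', 'rb') as data:  # ET.iterparse accepts a file object
    for lang_seg in parse_tmx(data):
        # each record maps ISO 639-3 codes to segments, e.g. {'deu': '...', 'eng': '...'}
        for lang, seg in lang_seg.items():
            print(lang, seg, sep='\t')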
Code Example #11
File: __init__.py Project: masonreznov/mtdata
def load_all():
    from mtdata.index import (statmt, paracrawl, tilde, literature,
                              joshua_indian, globalvoices, unitednations,
                              wikimatrix, other, neulab_tedtalks)
    from mtdata.index.opus import opus_index, jw300

    counts = {}
    subsets = [
        ('Statmt.org', statmt.load),
        ('Paracrawl', paracrawl.load),
        ('Tilde', tilde.load),
        ('JoshuaIndianCoprus', joshua_indian.load_all),
        ('GlobalVoices', globalvoices.load_all),
        ('UnitedNations', unitednations.load_all),
        ('OPUS', opus_index.load_all),
        ('OPUS_JW300', jw300.load_all),
        ('WikiMatrix', wikimatrix.load_all),
        ('Other', other.load_all),
        ('Neulab_TEDTalksv1', neulab_tedtalks.load_all),
    ]
    for name, loader in subsets:
        n = len(INDEX)
        loader(INDEX)
        counts[name] = len(INDEX) - n
    counts['Total'] = len(INDEX)

    counts = '  '.join([f'{n}:{c:,}' for n, c in counts.items()])
    log.info(f"Loaded entries: {counts}")
    literature.load(INDEX)
Code Example #12
File: main.py Project: MGheini/mtdata
def listing(langs, names, not_names=None, full=False, cache_dir=None):
    entries = get_entries(langs, names, not_names)
    log.info(f"Found {len(entries)}")
    for i, ent in enumerate(entries):
        print(ent.format(delim='\t'))
        if full:
            print(ent.cite or "CITATION_NOT_LISTED", end='\n\n')
    print(f"Total {len(entries)} entries")
Code Example #13
def main(inp, out, wmt21xml=False):
    parser = read_wmt21_xml if wmt21xml else read_sgm
    stream = parser(inp)
    with IO.writer(out) as out:
        count = 0
        for rec in stream:
            if isinstance(rec, str):
                rec = (rec, )
            line = '\t'.join(rec) + '\n'
            out.write(line)
            count += 1
        log.info(f"Wrote {count} lines to {out}")
Code Example #14
def lang_pair(string) -> LangPair:
    parts = string.strip().split('-')
    if len(parts) != 2:
        msg = f'expected value of form "xxx-yyz" eg "deu-eng"; given {string}'
        raise Exception(msg)
    std_codes = (bcp47(parts[0]), bcp47(parts[1]))
    std_form = '-'.join(str(lang) for lang in std_codes)
    if std_form != string:
        log.info(
            f"Suggestion: Use codes {std_form} instead of {string}."
            f" Let's make a little space for all languages of our planet 😢."
        )
    return std_codes
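A short usage sketch for lang_pair; the import path is an assumption. Per the snippet, non-canonical input such as 'de-en' still parses, but a suggestion to use the standardized form is logged.

# Hypothetical usage sketch; import path is an assumption.
from mtdata.main import lang_pair  # assumed import path

src_lang, tgt_lang = lang_pair('de-en')  # logs a suggestion to use 'deu-eng' instead
print(src_lang, tgt_lang)                # standardized BCP-47 tags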
Code Example #15
    def get_instance(cls):
        if not cls.obj:
            if not cached_index_file.exists():
                log.info("Creating a fresh index object")
                cached_index_file.parent.mkdir(exist_ok=True)
                lock_file = cached_index_file.with_suffix("._lock")
                with portalocker.Lock(lock_file, 'w', timeout=60) as fh:
                    # got lock, check cache is not created by parallel processes while we waited
                    if not cached_index_file.exists():
                        obj = Index()
                        log.info("Indexing all datasets...")
                        obj.load_all()
                        log.info(
                            f"Caching my index file at {cached_index_file}")
                        with open(cached_index_file, 'wb') as out:
                            pickle.dump(obj, out)

            assert cached_index_file.exists()
            log.info(f"Loading index from cache {cached_index_file}")
            with open(cached_index_file, 'rb') as inp:
                obj = pickle.load(inp)

            assert isinstance(
                obj, cls
            ), f'{cached_index_file} isnt valid. please move or remove it'
            cls.obj = obj
        return cls.obj
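get_instance above is a file-cached singleton: the first call builds the full index under a portalocker lock and pickles it to cached_index_file; subsequent calls (and parallel processes) unpickle the cache. A hedged usage sketch, assuming get_instance is exposed as a classmethod on Index:

# Hypothetical usage sketch; import path and attribute access are assumptions.
from mtdata.index import Index

index = Index.get_instance()  # builds the index on first use, loads the pickle afterwards
print(f'{len(index.entries):,} dataset entries indexed')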
Code Example #16
File: opus_index.py Project: thammegowda/mtdata
def load_all(index: Index):

    url_pat = 'https://object.pouta.csc.fi/OPUS-{corpus}/{version}/moses/{l1}-{l2}.txt.zip'
    group_id = 'OPUS'
    citation = index.ref_db.get_bibtex('tiedemann2012parallel')
    skip_counts = defaultdict(int)
    dupes = defaultdict(set)
    assert data_file.exists()
    assert data_file.stat().st_size > 0

    with data_file.open() as lines:
        for line in lines:
            line = line.strip()
            if not line:  # empty lines in the top and bottom
                continue
            assert len(line.split('\t')) == 4, line
            corpus, version, l1, l2 = line.split('\t')
            url = url_pat.format(corpus=corpus, version=version, l1=l1, l2=l2)
            iso_l1, iso_l2 = bcp47.try_parse(
                l1, default=None), bcp47.try_parse(l2, default=None)
            if not iso_l1 or not iso_l2:
                if not iso_l1:
                    skip_counts[str(l1)] += 1
                if not iso_l2:
                    skip_counts[str(l2)] += 1
                continue
            version_cln = version.replace('-', '').lower()
            corpus_cln = corpus.replace('-', '_').lower()

            data_id = DatasetId(group=group_id,
                                name=corpus_cln,
                                version=version_cln,
                                langs=(iso_l1, iso_l2))
            if data_id in index:
                dupes[corpus].add(f'{l1}-{l2}')
                continue
            entry = Entry(did=data_id,
                          url=url,
                          cite=citation,
                          in_paths=[f'*.{l1}', f'*.{l2}'],
                          in_ext='txt')
            index.add_entry(entry)
        if skip_counts:
            skip_counts = list(
                sorted(dict(skip_counts).items(),
                       key=lambda x: x[1],
                       reverse=True))
            log.info(f"Skipped lang counts: {skip_counts}")
        if dupes:
            log.info(f"Duplicates langs: {dupes}")
Code Example #17
def read_sgm_xml(data: Path) -> Iterator[str]:
    """Extract sgm using XML parse
    This one breaks if there is any error in XML e.g. an & is not escaped ;
      see newstest2019-frde-ref.de.sgm for example!
    """
    with IO.reader(data) as data:
        context = ET.iterparse(data, events=['end'])
        segs = (el for event, el in context if el.tag == 'seg')
        count = 0
        for seg in segs:
            yield seg.text
            seg.clear()
            count += 1
        log.info(f"read {count} segments from {data}")
Code Example #18
 def extract(self, archive_file: Path, ext: str, x_dir: Path):
     assert archive_file.exists(), f'{archive_file} not found'
     x_dir.mkdir(parents=True, exist_ok=True)
     if ext in {'tar', 'tgz', 'tar.gz', 'tar.bz2', 'tbz2', 'tar.xz', 'txz'}:
         log.info(f"Going to extract tar {archive_file} --> {x_dir}")
         import tarfile
         with tarfile.open(archive_file) as tar:
             tar.extractall(path=x_dir)
     elif ext == 'zip':
         log.info(f"Going to extract zip {archive_file} --> {x_dir}")
         from zipfile import ZipFile
         with ZipFile(archive_file) as zip:
             zip.extractall(path=x_dir)
     else:
         raise Exception(f'"{ext}" type extraction not supported')
Code Example #19
File: utils.py Project: thammegowda/mtdata
 def extract(self):
     dir_name = self.extracted_name()
     out_path = self.root.parent / dir_name
     valid_path = self.root.parent / (dir_name + '.valid')
     lock_path = self.root.parent / (dir_name + '.lock')
     if not valid_path.exists():
         with portalocker.Lock(lock_path, 'w',
                               timeout=FILE_LOCK_TIMEOUT) as _:
             if valid_path.exists():
                 return  # extracted by parallel process
             log.info(f"extracting {self.root}")
             with tarfile.open(self.root) as tar:
                 tar.extractall(out_path)
             valid_path.touch()
     return out_path
Code Example #20
File: utils.py Project: thammegowda/mtdata
    def open_old(self, mode='r', **kwargs):
        assert mode in ('r', 'rt'), f'only "r" is supported, given: {mode}'
        log.info(f"Reading tar: {self.root}?{self.name}")
        container = tarfile.open(self.root, mode='r', encoding='utf-8')
        stream = container.extractfile(self.name)
        reader = io.TextIOWrapper(stream, **kwargs)
        reader_close = reader.close  # original close

        def close(*args, **kwargs):
            reader_close()
            stream.close()
            container and container.close()

        reader.close = close  # hijack
        return reader
Code Example #21
File: main.py Project: masonreznov/mtdata
def get_data(args):
    assert args.train_names or args.test_names, 'Required --train or --test or both'
    dataset = Dataset.prepare(args.langs,
                              train_names=args.train_names,
                              test_names=args.test_names,
                              out_dir=args.out,
                              cache_dir=args.cache,
                              merge_train=args.merge)
    cli_sig = f'-l {"-".join(args.langs)}'
    cli_sig += f' -tr {" ".join(args.train_names)}' if args.train_names else ''
    cli_sig += f' -ts {" ".join(args.test_names)}' if args.test_names else ''
    sig = f'mtdata get {cli_sig} -o <out-dir>\nmtdata version {mtdata.__version__}\n'
    log.info(f'Dataset is ready at {dataset.dir}')
    log.info(f'mtdata args for reproducing this dataset:\n {sig}')
    with IO.writer(args.out / 'mtdata.signature.txt', append=True) as w:
        w.write(sig)
Code Example #22
File: __init__.py Project: thammegowda/mtdata
 def load(cls, *paths) -> Dict[str, 'Recipe']:
     assert len(paths) > 0
     recipes = {}
     for path in paths:
         log.info(f"Loading recipes from {path}")
         with open(path) as inp:
             recipes_raw = yaml.load(inp)
         for r in recipes_raw:
             assert isinstance(r, dict), f'{r} expected to be a dict'
             try:
                 r = cls.parse(**r)
             except:
                 log.error(f"Error while parsing recipe:\n{r}")
                 raise
             assert r.id not in recipes, f'{r} is a duplicate'
             recipes[r.id] = r
     return recipes
Code Example #23
def read_tmx(path: Union[Path, str], langs=None):
    """
    reads a TMX file as records
    :param path: path to .tmx file
    :param langs: (lang1, lang2) codes eg (de, en); when it is None the code tries to auto detect
    :return: stream of (text1, text2)
    """
    passes = 0
    fails = 0
    if langs:
        assert len(langs) == 2
        langs = [bcp47(lang) for lang in langs]
        assert not BCP47Tag.are_compatible(
            *langs), f'{langs} expected to be different (/unambiguous)'
    with IO.reader(path) as data:
        recs = parse_tmx(data)
        for lang_seg in recs:
            if langs is None:
                log.warning(
                    "langs not set; this could result in language mismatch")
                if len(lang_seg) == 2:
                    langs = tuple(lang_seg.keys())
                else:
                    raise Exception(
                        f"Language autodetect for TMX only supports 2 languages,"
                        f" but provided with {lang_seg.keys()} in TMX {path}")
            seg1, seg2 = None, None
            for lang, seg in lang_seg.items():
                if BCP47Tag.are_compatible(langs[0], lang):
                    seg1 = seg
                elif BCP47Tag.are_compatible(langs[1], lang):
                    seg2 = seg
                # else ignore
            if seg1 and seg2:  # both segs are found
                yield seg1, seg2
                passes += 1
            else:
                fails += 1
    if passes == 0:
        if fails == 0:
            raise Exception(f"Empty TMX {path}")
        raise Exception(f"Nothing for {langs[0]}-{langs[1]} in TMX {path}")
    if fails != 0:
        log.warning(
            f"Skipped {fails} entries due to language mismatch in TMX {path}")
    log.info(f"Extracted {passes} pairs from TMX {path}")
Code Example #24
File: data.py Project: kpu/mtdata
 def add_part(self, dir_path: Path, entry: Entry, drop_noise=False):
     path = self.cache.get_entry(entry)
     swap = entry.is_swap(self.langs)
     parser = Parser(path,
                     langs=self.langs,
                     ext=entry.in_ext or None,
                     ent=entry)
     langs = '_'.join(self.langs)
     l1 = (dir_path /
           f'{entry.name}-{langs}').with_suffix(f'.{self.langs[0]}')
     l2 = (dir_path /
           f'{entry.name}-{langs}').with_suffix(f'.{self.langs[1]}')
     mode = dict(mode='w', encoding='utf-8', errors='ignore')
     with l1.open(**mode) as f1, l2.open(**mode) as f2:
         count, skips, noise = 0, 0, 0
         for rec in parser.read_segs():
             rec = rec[:2]  # get the first two recs
             if len(rec) != 2:
                 skips += 1
                 continue
             if drop_noise and entry.is_noisy(seg1=rec[0], seg2=rec[1]):
                 skips += 1
                 noise += 1
                 continue
             sent1, sent2 = [s.strip() for s in rec]
             if not sent1 or not sent2:
                 skips += 1
                 continue
             if swap:
                 sent2, sent1 = sent1, sent2
             sent1 = sent1.replace('\n', ' ').replace('\t', ' ')
             sent2 = sent2.replace('\n', ' ').replace('\t', ' ')
             f1.write(f'{sent1}\n')
             f2.write(f'{sent2}\n')
             count += 1
         msg = f'Looks like an error. {count} segs are valid {skips} are invalid: {entry}'
         assert count > 0, msg
         if skips > count:
             log.warning(msg)
         if noise > 0:
             log.info(
                 f"{entry}: Noise : {noise:,}/{count:,} => {100*noise/count:.4f}%"
             )
         log.info(f"wrote {count} lines to {l1} == {l2}")
     return count, skips
Code Example #25
def read_sgm_regex(data: Path) -> Iterator[str]:
    """
    Extract sgm using regex.
    assumes each seg is on its own line, of the form <seg id="xx"> yy</seg>,
    and that line breaks separate segments
    :param data:
    :return:
    """
    patt = re.compile(r'<seg id="(.*)">(.*)</seg>')
    count = 0
    with IO.reader(data) as data:
        for line in data:
            line = line.strip()
            match = patt.search(line)
            if match:
                yield unescape(match.group(2))
                count += 1
    log.info(f"read {count} segments from {data}")
Code Example #26
def get_data(langs,
             out_dir,
             train_dids=None,
             test_dids=None,
             dev_dids=None,
             merge_train=False,
             compress=False,
             drop_dupes=False,
             drop_tests=False,
             fail_on_error=False,
             n_jobs=DEF_N_JOBS,
             **kwargs):
    if kwargs:
        log.warning(f"Args are ignored: {kwargs}")
    from mtdata.data import Dataset
    assert train_dids or test_dids, 'Required --train or --test or both'
    dataset = Dataset.prepare(langs,
                              train_dids=train_dids,
                              test_dids=test_dids,
                              out_dir=out_dir,
                              dev_dids=dev_dids,
                              cache_dir=CACHE_DIR,
                              merge_train=merge_train,
                              compress=compress,
                              drop_dupes=drop_dupes,
                              drop_tests=drop_tests,
                              fail_on_error=fail_on_error,
                              n_jobs=n_jobs)
    cli_sig = f'-l {"-".join(str(l) for l in langs)}'
    for flag, dids in [('-tr', train_dids), ('-ts', test_dids),
                       ('-dv', dev_dids)]:
        if dids:
            cli_sig += f' {flag} {" ".join(map(str, dids))}'
    for flag, val in [('--merge', merge_train), ('--compress', compress),
                      ('-dd', drop_dupes), ('-dt', drop_tests)]:
        if val:
            cli_sig += ' ' + flag
    sig = f'mtdata get {cli_sig} -o <out-dir>\nmtdata version {mtdata.__version__}\n'
    log.info(f'Dataset is ready at {dataset.dir}')
    log.info(f'mtdata args for reproducing this dataset:\n {sig}')
    with IO.writer(out_dir / 'mtdata.signature.txt', append=True) as w:
        w.write(sig)
Code Example #27
def read_wmt21_xml(data):
    """
    This is a new XML format (instead of SGM) introduced in WMT2021
    :param data:
    :return:
    """
    tree = ET.parse(data)

    # 1. Assumes exactly one translation
    # 2. buffering in memory.  this can be improved with stream parsing
    def xpath_all(tree, xpath):
        return (clean(seg.text) for seg in tree.findall(xpath))

    srcs = list(xpath_all(tree.getroot(), xpath=".//src//seg"))
    tgts = list(xpath_all(tree.getroot(), xpath=".//ref//seg"))
    assert len(srcs) == len(
        tgts
    ), f'{data} has unequal number of segs: {len(srcs)} == {len(tgts)}?'
    yield from zip(srcs, tgts)
    log.info(f"Read {len(srcs)} segs from {data}")
Code Example #28
File: data.py Project: thammegowda/mtdata
 def add_part_thread(self, args):
     fail_on_error = args.pop('fail_on_error', False)
     ent = args['entry']
     assert isinstance(ent, Entry)
     try:
         n_good, n_bad = self.add_part(**args)
         if max(n_good,
                n_bad) >= 0:  # -1 for skipped record because it is valid
             log.info(
                 f"{ent.did.name} : found {n_good:} segments and {n_bad:} errors"
             )
     except Exception as e:
         log.error(f"Unable to add {ent.did}: {e}")
         if fail_on_error:
             raise e
         msg = str(e).replace('\n', '\t')
         with portalocker.Lock(self.errors_file,
                               'a',
                               timeout=FILE_LOCK_TIMEOUT) as fh:
             # self.errors_file.open('a').write(f"{ent.did}\t{msg}\n")
             fh.write(f"{ent.did}\t{msg}\n")
Code Example #29
    def extract(self, archive_file: Path, ext: str, x_dir: Path):
        assert archive_file.exists(), f'{archive_file} not found'
        valid_file = self.get_flag_file(x_dir)
        lock_file = valid_file.with_suffix('._lock')
        if x_dir.exists() and valid_file.exists():
            return  # already extracted

        x_dir.mkdir(parents=True, exist_ok=True)
        log.info(
            f"Acquiring lock on {lock_file}\nif this gets stuck, delete the lock and restart"
        )
        with portalocker.Lock(lock_file, 'w', timeout=MAX_TIMEOUT) as fh:
            if valid_file.exists() and x_dir.exists():
                return
            if ext in {
                    'tar', 'tgz', 'tar.gz', 'tar.bz2', 'tbz2', 'tar.xz', 'txz'
            }:
                log.info(f"Going to extract tar {archive_file} --> {x_dir}")
                import tarfile
                with tarfile.open(archive_file) as tar:
                    tar.extractall(path=x_dir)
            elif ext == 'zip':
                log.info(f"Going to extract zip {archive_file} --> {x_dir}")
                from zipfile import ZipFile
                with ZipFile(archive_file) as zip:
                    zip.extractall(path=x_dir)
            else:
                raise Exception(f'"{ext}" type extraction not supported')

            valid_file.touch()
Code Example #30
def list_data(langs,
              names,
              not_names=None,
              full=False,
              groups=None,
              not_groups=None,
              id_only=False):
    from mtdata.index import get_entries
    entries = get_entries(langs,
                          names,
                          not_names,
                          groups=groups,
                          not_groups=not_groups,
                          fuzzy_match=True)
    for i, ent in enumerate(entries):
        if id_only:
            print(ent.did)
        else:
            print(ent.format(delim='\t'))
        if full:
            print(ent.cite or "CITATION_NOT_LISTED", end='\n\n')
    log.info(f"Total {len(entries)} entries")