예제 #1
0
파일: data.py 프로젝트: kpu/mtdata
    def add_train_entries(self, entries, merge_train=False):
        """Add training entries as part files and optionally merge them into one train file.

        Entries are written under ``self.train_parts_dir`` by ``self.add_parts``.
        When ``merge_train`` is true, every ``*.<l1>`` part file found there is
        paired with its ``*.<l2>`` sibling and all segment pairs are appended to
        ``train.<l1>`` / ``train.<l2>`` in ``self.dir``. ``train.meta.gz`` receives,
        line by line, the name of the part each segment pair came from, and the
        per-part counts are written to ``train.stats.json``.

        :param entries: entries to add as training parts
        :param merge_train: whether to merge the parts into a single train file
        :return: ``None`` when not merging; otherwise
            ``{'total': <int>, 'parts': {<part name>: <int>}}``
        """
        self.add_parts(self.train_parts_dir,
                       entries,
                       drop_noise=self.drop_train_noise)
        if not merge_train:
            return
        # merge
        l1, l2 = self.langs
        l1_suffix = f".{l1}"
        l1_files = list(self.train_parts_dir.glob(f"*{l1_suffix}"))
        assert l1_files and len(l1_files) >= len(entries)

        l2_files = [l1_f.with_suffix(f".{l2}") for l1_f in l1_files]
        assert all(l2_f.exists() for l2_f in l2_files)
        log.info(f"Going to merge {len(l1_files)} files as one train file")
        counts = coll.defaultdict(int)
        of1 = self.dir / f'train.{l1}'
        of2 = self.dir / f'train.{l2}'
        of3 = self.dir / 'train.meta.gz'
        with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
            for if1, if2 in zip(l1_files, l2_files):
                # The glob above guarantees the name ends with ".<l1>", so slice
                # it off exactly. str.rstrip would strip any trailing run of the
                # characters in ".<l1>" and truncate names such as "x_sde.de".
                name = if1.name[:-len(l1_suffix)]
                for seg1, seg2 in self.read_parallel(if1, if2):
                    w1.write(seg1 + '\n')
                    w2.write(seg2 + '\n')
                    w3.write(name + '\n')
                    counts[name] += 1
        total = sum(counts.values())
        counts = {'total': total, 'parts': counts}
        counts_msg = json.dumps(counts, indent=2)
        log.info('Train stats:\n' + counts_msg)
        IO.write_lines(self.dir / 'train.stats.json', counts_msg)
        return counts
예제 #2
0
파일: data.py 프로젝트: thammegowda/mtdata
 def add_part(self,
              dir_path: Path,
              entry: Entry,
              drop_noise=False,
              compress=False):
     """Write the segment pairs of one entry into a pair of part files in ``dir_path``.

     The entry is fetched via ``self.cache.get_entry`` and parsed with ``Parser``.
     Only the first two fields of each record are used. A record is skipped when:
     it has fewer than two fields; ``drop_noise`` is set and ``entry.is_noisy``
     flags it; or either side is empty after stripping. Newlines, tabs and
     carriage returns inside a segment become spaces. On success a
     ``.valid.<did>`` flag file is touched, so a later call for the same entry
     returns early.

     :param dir_path: directory that receives the part files and the flag file
     :param entry: dataset entry to write
     :param drop_noise: skip records for which ``entry.is_noisy`` is true
     :param compress: forwarded to ``self.get_paths`` to pick the output paths
     :return: ``(count, skips)``, the numbers of written and skipped records, or
         ``(-1, -1)`` when the flag file already exists
     :raises AssertionError: when no record could be written
     """
     flag_file = dir_path / f'.valid.{entry.did}'
     if flag_file.exists():
         log.info(f"{flag_file} exists. Skipping")
         return -1, -1
     path = self.cache.get_entry(entry)
     parser = Parser(path, ext=entry.in_ext or None, ent=entry)
     l1, l2 = self.get_paths(dir_path, entry, compress=compress)
     io_args = dict(encoding='utf-8', errors='ignore')
     # Newline, tab and carriage return all map to a space in a single pass, so
     # every segment stays on exactly one line of its output file.
     flatten = str.maketrans('\n\t\r', '   ')
     with IO.writer(l1, **io_args) as f1, IO.writer(l2, **io_args) as f2:
         count, skips, noise = 0, 0, 0
         for rec in parser.read_segs():
             rec = rec[:2]  # only the first two fields are used
             if len(rec) != 2:
                 skips += 1
                 continue
             if drop_noise and entry.is_noisy(seg1=rec[0], seg2=rec[1]):
                 skips += 1
                 noise += 1
                 continue
             sent1, sent2 = [s.strip() for s in rec]
             if not sent1 or not sent2:
                 skips += 1
                 continue
             f1.write(f'{sent1.translate(flatten)}\n')
             f2.write(f'{sent2.translate(flatten)}\n')
             count += 1
         msg = f'Looks like an error. {count} segs are valid {skips} are invalid: {entry}'
         assert count > 0, msg
         if skips > count:
             log.warning(msg)
         if noise > 0:
             log.info(
                 f"{entry}: Noise : {noise:,}/{count:,} => {100 * noise / count:.4f}%"
             )
         log.info(f"wrote {count} lines to {l1} == {l2}")
     flag_file.touch()
     return count, skips
예제 #3
0
파일: data.py 프로젝트: thammegowda/mtdata
 def cat_bitexts(self, in_paths: List[Tuple[Path, Path]],
                 out_paths: Tuple[Path, Path]):
     """Concatenate several parallel file pairs into one output pair.

     Segment pairs from each ``(if1, if2)`` in ``in_paths`` are read with
     ``self.read_parallel`` and appended, in order, to ``out_paths``. A progress
     bar advances once per input pair.

     :param in_paths: input ``(side1, side2)`` file pairs; each must exist
     :param out_paths: ``(side1, side2)`` output files; their parent
         directories are created if missing
     :raises AssertionError: when an input file does not exist
     """
     of1, of2 = out_paths
     # parents=True: without it a missing intermediate directory raises
     # FileNotFoundError instead of being created.
     of1.parent.mkdir(parents=True, exist_ok=True)
     of2.parent.mkdir(parents=True, exist_ok=True)
     with pbar_man.counter(color='green', total=len(in_paths), unit='it', desc="Merging") as pbar, \
             IO.writer(of1) as w1, IO.writer(of2) as w2:
         for if1, if2 in in_paths:
             assert if1.exists(), f'{if1} does not exist'
             assert if2.exists(), f'{if2} does not exist'
             for seg1, seg2 in self.read_parallel(if1, if2):
                 w1.write(seg1 + '\n')
                 w2.write(seg2 + '\n')
             pbar.update()
예제 #4
0
파일: tmx.py 프로젝트: kpu/mtdata
def main(inp, out, langs):
    """Convert the TMX records of ``inp`` into tab-separated lines in ``out``.

    :param inp: input passed to ``read_tmx``
    :param out: output path passed to ``IO.writer``
    :param langs: languages passed to ``read_tmx``
    """
    recs = read_tmx(inp, langs=langs)
    # Keep ``out`` bound to the path. Rebinding it to the writer made the final
    # log line print the file object instead of where the lines went.
    with IO.writer(out) as writer:
        count = 0
        for rec in recs:
            # A tab inside a segment would be read back as an extra column.
            rec = [seg.replace('\t', ' ') for seg in rec]
            writer.write('\t'.join(rec) + '\n')
            count += 1
        log.warning(f"Wrote {count} lines to {out}")
예제 #5
0
def main(inp, out):
    """Write the segments of the SGM input ``inp`` to ``out``, one per line.

    :param inp: input passed to ``read_sgm``
    :param out: output path passed to ``IO.writer``
    """
    segs = read_sgm(inp)
    # Keep ``out`` bound to the path. Rebinding it to the writer made the final
    # log line print the file object instead of where the lines went.
    with IO.writer(out) as writer:
        count = 0
        for seg in segs:
            seg = seg.replace('\t', ' ')
            writer.write(seg + '\n')
            count += 1
        log.warning(f"Wrote {count} lines to {out}")
예제 #6
0
def main(inp, out, wmt21xml=False):
    """Write the records of ``inp`` to ``out`` as tab-separated lines.

    :param inp: input passed to the selected reader
    :param out: output path passed to ``IO.writer``
    :param wmt21xml: read with ``read_wmt21_xml`` instead of ``read_sgm``
    """
    parser = read_wmt21_xml if wmt21xml else read_sgm
    stream = parser(inp)
    # Keep ``out`` bound to the path. Rebinding it to the writer made the final
    # log line print the file object instead of where the lines went.
    with IO.writer(out) as writer:
        count = 0
        for rec in stream:
            if isinstance(rec, str):
                # Single-string records are treated as one-column rows.
                rec = (rec, )
            line = '\t'.join(rec) + '\n'
            writer.write(line)
            count += 1
        log.info(f"Wrote {count} lines to {out}")
예제 #7
0
파일: main.py 프로젝트: masonreznov/mtdata
def get_data(args):
    """Prepare the dataset described by CLI ``args`` and record how to reproduce it.

    Calls ``Dataset.prepare`` with the languages, train/test names, output and
    cache directories, and merge option from ``args``. It then appends an
    ``mtdata get …`` command line plus the mtdata version to
    ``<out>/mtdata.signature.txt``.

    :param args: parsed CLI arguments (``langs``, ``train_names``, ``test_names``,
        ``out``, ``cache``, ``merge``)
    :raises AssertionError: when neither train nor test names are given
    """
    assert args.train_names or args.test_names, 'Required --train or --test or both'
    dataset = Dataset.prepare(args.langs,
                              train_names=args.train_names,
                              test_names=args.test_names,
                              out_dir=args.out,
                              cache_dir=args.cache,
                              merge_train=args.merge)
    cli_sig = f'-l {"-".join(str(lang) for lang in args.langs)}'
    cli_sig += f' -tr {" ".join(args.train_names)}' if args.train_names else ''
    cli_sig += f' -ts {" ".join(args.test_names)}' if args.test_names else ''
    # Merging changes the produced files, so the reproduction command must include it.
    cli_sig += ' --merge' if args.merge else ''
    sig = f'mtdata get {cli_sig} -o <out-dir>\nmtdata version {mtdata.__version__}\n'
    log.info(f'Dataset is ready at {dataset.dir}')
    log.info(f'mtdata args for reproducing this dataset:\n {sig}')
    with IO.writer(args.out / 'mtdata.signature.txt', append=True) as w:
        w.write(sig)
예제 #8
0
def get_data(langs,
             out_dir,
             train_dids=None,
             test_dids=None,
             dev_dids=None,
             merge_train=False,
             compress=False,
             drop_dupes=False,
             drop_tests=False,
             fail_on_error=False,
             n_jobs=DEF_N_JOBS,
             **kwargs):
    """Prepare a dataset in ``out_dir`` and append its reproduction signature.

    Unknown keyword arguments are logged and ignored. After
    ``Dataset.prepare`` finishes, an ``mtdata get …`` command line and the
    mtdata version are appended to ``<out_dir>/mtdata.signature.txt``.

    :raises AssertionError: when neither ``train_dids`` nor ``test_dids`` is given
    """
    if kwargs:
        log.warning(f"Args are ignored: {kwargs}")
    from mtdata.data import Dataset
    assert train_dids or test_dids, 'Required --train or --test or both'
    dataset = Dataset.prepare(langs,
                              train_dids=train_dids,
                              test_dids=test_dids,
                              out_dir=out_dir,
                              dev_dids=dev_dids,
                              cache_dir=CACHE_DIR,
                              merge_train=merge_train,
                              compress=compress,
                              drop_dupes=drop_dupes,
                              drop_tests=drop_tests,
                              fail_on_error=fail_on_error,
                              n_jobs=n_jobs)

    # Collect CLI tokens and join them once at the end.
    tokens = ['-l', '-'.join(map(str, langs))]
    id_options = (('-tr', train_dids), ('-ts', test_dids), ('-dv', dev_dids))
    for option, ids in id_options:
        if not ids:
            continue
        tokens.append(option)
        tokens.extend(map(str, ids))
    switch_options = (('--merge', merge_train), ('--compress', compress),
                      ('-dd', drop_dupes), ('-dt', drop_tests))
    tokens.extend(option for option, enabled in switch_options if enabled)
    cli_sig = ' '.join(tokens)

    sig = f'mtdata get {cli_sig} -o <out-dir>\nmtdata version {mtdata.__version__}\n'
    log.info(f'Dataset is ready at {dataset.dir}')
    log.info(f'mtdata args for reproducing this dataset:\n {sig}')
    signature_path = out_dir / 'mtdata.signature.txt'
    with IO.writer(signature_path, append=True) as sig_writer:
        sig_writer.write(sig)
예제 #9
0
파일: data.py 프로젝트: thammegowda/mtdata
    def add_train_entries(self,
                          entries,
                          merge_train=False,
                          compress=False,
                          drop_hashes=None):
        """Add training entries as part files and optionally merge them into one train file.

        Entries are written under ``self.train_parts_dir`` via ``self.add_parts``.
        When ``merge_train`` is true, each entry's pair of part files is read back,
        oriented to match ``self.langs``, and appended to ``train.<lang1>`` /
        ``train.<lang2>`` in ``self.dir``. Both files get a ``.DEF_COMPRESS``
        extension when ``compress`` is set. ``train.meta.<DEF_COMPRESS>`` receives,
        per written line, the entry id it came from. Summary statistics are logged
        and written to ``train.stats.json``.

        Filtering (only when ``self.drop_dupes`` or ``self.drop_tests`` is set):
          * a pair is dropped and counted in ``test_overlap_skips`` if the hash of
            the pair, or of either side, is in ``drop_hashes``;
          * if ``self.drop_dupes`` is set, a pair whose hash was already written
            is dropped and counted in ``dupes_skips``.

        :param entries: entries to add as training parts
        :param merge_train: whether to merge the parts into a single train file
        :param compress: forwarded to ``add_parts``; also selects compressed
            merged-output file names
        :param drop_hashes: optional collection of ``hash()`` values to exclude
        :return: ``None`` when not merging; otherwise the ``counts`` dict with keys
            ``total``, ``dupes_skips``, ``test_overlap_skips``, ``selected``, each
            mapping entry id -> count
        """
        self.add_parts(self.train_parts_dir,
                       entries,
                       drop_noise=self.drop_train_noise,
                       compress=compress,
                       desc='Training sets',
                       fail_on_error=self.fail_on_error)
        if not merge_train:
            return
        lang1, lang2 = self.langs
        # paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2)
        # Map entry id -> (lang1-side path, lang2-side path).
        paired_files = {}
        for ent in entries:
            e1, e2 = self.get_paths(self.train_parts_dir, ent)
            # NOTE(review): fail_on_incompat=True presumably raises when the entry's
            # langs are incompatible with self.langs; confirm in BCP47Tag.
            _, swapped = BCP47Tag.check_compat_swap(self.langs,
                                                    ent.did.langs,
                                                    fail_on_incompat=True)
            if swapped:
                # Entry is stored in the opposite direction; swap so e1 is lang1.
                e1, e2 = e2, e1
            paired_files[str(ent.did)] = e1, e2

        log.info(f"Going to merge {len(paired_files)} files as one train file")
        compress_ext = f'.{DEF_COMPRESS}' if compress else ''
        l1_ext = f'{lang1}{compress_ext}'
        l2_ext = f'{lang2}{compress_ext}'
        of1 = self.dir / f'train.{l1_ext}'
        of2 = self.dir / f'train.{l2_ext}'
        of3 = self.dir / f'train.meta.{DEF_COMPRESS}'

        # Per-entry counters; the key order here is the key order in train.stats.json.
        counts = dict(total=coll.defaultdict(int),
                      dupes_skips=coll.defaultdict(int),
                      test_overlap_skips=coll.defaultdict(int),
                      selected=coll.defaultdict(int))
        # Hashes of pairs already written, used for de-duplication.
        train_hashes = set()

        with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
            with pbar_man.counter(color='green',
                                  total=len(paired_files),
                                  unit='it',
                                  desc="Merging",
                                  autorefresh=False) as pbar:
                for name, (if1, if2) in paired_files.items():
                    for seg1, seg2 in self.read_parallel(if1, if2):
                        counts['total'][name] += 1
                        if self.drop_dupes or self.drop_tests:
                            # NOTE(review): built-in hash() of str is salted per
                            # process (PYTHONHASHSEED), so drop_hashes must be computed
                            # in this same interpreter process. A hash collision
                            # would also drop an unrelated pair.
                            hash_val = hash((seg1, seg2))
                            if drop_hashes and (hash_val in drop_hashes
                                                or hash(seg1) in drop_hashes
                                                or hash(seg2) in drop_hashes):
                                counts['test_overlap_skips'][name] += 1
                                continue
                            if self.drop_dupes:
                                if hash_val in train_hashes:
                                    counts['dupes_skips'][name] += 1
                                    continue
                                train_hashes.add(hash_val)
                        w1.write(seg1 + '\n')
                        w2.write(seg2 + '\n')
                        w3.write(name + '\n')
                        counts['selected'][name] += 1
                    pbar.update()

        stats = dict(selected=sum(counts['selected'].values()),
                     total=sum(counts['total'].values()),
                     counts=counts)

        stats_msg = json.dumps(stats, indent=2)
        log.info('Train stats:\n' + stats_msg)
        IO.write_lines(self.dir / 'train.stats.json', stats_msg)
        # NOTE(review): returns the per-entry counts, not the `stats` summary written above.
        return counts