Пример #1
0
Файл: data.py Проект: kpu/mtdata
    def add_train_entries(self, entries, merge_train=False):
        """Add training parts and optionally merge them into one set of train files.

        Parts are produced under ``self.train_parts_dir`` by ``self.add_parts``.
        When ``merge_train`` is true, each ``*.<l1>`` part file is paired with
        its ``*.<l2>`` sibling, and all pairs are concatenated into
        ``self.dir/train.<l1>`` and ``self.dir/train.<l2>``. The source part
        name of every line is written to ``self.dir/train.meta.gz``. Per-part
        line counts are logged and saved to ``self.dir/train.stats.json``.

        :param entries: entries to add; forwarded to ``self.add_parts``
        :param merge_train: merge the parts into single train files when true
        :return: ``None`` if ``merge_train`` is false, otherwise
            ``{'total': <int>, 'parts': {<part name>: <line count>}}``
        """
        self.add_parts(self.train_parts_dir,
                       entries,
                       drop_noise=self.drop_train_noise)
        if not merge_train:
            return
        # merge
        l1, l2 = self.langs
        l1_suffix = f".{l1}"
        # glob() yields files in an arbitrary, filesystem-dependent order; sort
        # so that the merged train files are reproducible.
        l1_files = sorted(self.train_parts_dir.glob(f"*{l1_suffix}"))
        assert l1_files and len(l1_files) >= len(entries), \
            f"Expected at least {len(entries)} *{l1_suffix} parts in {self.train_parts_dir}"

        l2_files = [l1_f.with_suffix(f".{l2}") for l1_f in l1_files]
        missing = [str(l2_f) for l2_f in l2_files if not l2_f.exists()]
        assert not missing, f"Missing {l2} side of parts: {missing}"
        log.info(f"Going to merge {len(l1_files)} files as one train file")
        part_counts = coll.defaultdict(int)
        of1 = self.dir / f'train.{l1}'
        of2 = self.dir / f'train.{l2}'
        of3 = self.dir / 'train.meta.gz'
        with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
            for if1, if2 in zip(l1_files, l2_files):
                # Drop the exact ".<l1>" suffix, which the glob above guarantees.
                # str.rstrip() strips a *character set*, so it would also eat
                # trailing letters of the part name (e.g. "x-deen.en" -> "x-d").
                name = if1.name[:-len(l1_suffix)]
                for seg1, seg2 in self.read_parallel(if1, if2):
                    w1.write(seg1 + '\n')
                    w2.write(seg2 + '\n')
                    w3.write(name + '\n')
                    part_counts[name] += 1
        total = sum(part_counts.values())
        stats = {'total': total, 'parts': part_counts}
        counts_msg = json.dumps(stats, indent=2)
        log.info('Train stats:\n' + counts_msg)
        IO.write_lines(self.dir / 'train.stats.json', counts_msg)
        return stats
Пример #2
0
    def add_train_entries(self,
                          entries,
                          merge_train=False,
                          compress=False,
                          drop_hashes=None):
        """Add training parts and optionally merge them into one bitext.

        Parts are first produced under ``self.train_parts_dir`` by
        ``self.add_parts``. When ``merge_train`` is true, the parts of each
        entry are read in parallel and appended to ``train.<lang1>[.ext]`` and
        ``train.<lang2>[.ext]`` under ``self.dir``. ``.ext`` is
        ``.<DEF_COMPRESS>`` only when ``compress`` is set. The dataset id of
        each kept line goes to ``train.meta.<DEF_COMPRESS>``.

        When ``self.drop_dupes`` or ``self.drop_tests`` is truthy, segment
        pairs are filtered:
          * A pair is dropped as a test overlap if ``drop_hashes`` is non-empty
            and contains ``hash((seg1, seg2))``, ``hash(seg1)`` or
            ``hash(seg2)``. Python's ``str`` hash is salted per process, so
            ``drop_hashes`` must come from this same interpreter run.
          * With ``self.drop_dupes``, a pair whose hash was already kept is
            dropped as a duplicate.

        :param entries: entries to add; forwarded to ``self.add_parts``
        :param merge_train: merge the parts into single train files when true
        :param compress: forwarded to ``self.add_parts``; also compresses the
            merged per-language outputs
        :param drop_hashes: optional collection of hashes to exclude
            (presumably built from held-out sets -- confirm with caller)
        :return: ``None`` if ``merge_train`` is false, otherwise a dict with
            keys ``total``, ``dupes_skips``, ``test_overlap_skips`` and
            ``selected``, each mapping dataset id -> line count
        """
        self.add_parts(self.train_parts_dir,
                       entries,
                       drop_noise=self.drop_train_noise,
                       compress=compress,
                       desc='Training sets',
                       fail_on_error=self.fail_on_error)
        if not merge_train:
            return
        lang1, lang2 = self.langs

        # dataset id -> (lang1-side file, lang2-side file). The pair is swapped
        # when the dataset's language order is the reverse of self.langs.
        paired_files = {}
        for ent in entries:
            side_a, side_b = self.get_paths(self.train_parts_dir, ent)
            _compat, swapped = BCP47Tag.check_compat_swap(self.langs,
                                                          ent.did.langs,
                                                          fail_on_incompat=True)
            paired_files[str(ent.did)] = (side_b, side_a) if swapped else (side_a, side_b)

        log.info(f"Going to merge {len(paired_files)} files as one train file")
        suffix = f'.{DEF_COMPRESS}' if compress else ''
        out_l1 = self.dir / f'train.{lang1}{suffix}'
        out_l2 = self.dir / f'train.{lang2}{suffix}'
        out_meta = self.dir / f'train.meta.{DEF_COMPRESS}'

        counts = {key: coll.defaultdict(int)
                  for key in ('total', 'dupes_skips', 'test_overlap_skips', 'selected')}
        seen = set()

        def skip_reason(seg1, seg2):
            """Return the ``counts`` key to bump when the pair is dropped, else None."""
            if not (self.drop_dupes or self.drop_tests):
                return None
            pair_hash = hash((seg1, seg2))
            if drop_hashes and (pair_hash in drop_hashes
                                or hash(seg1) in drop_hashes
                                or hash(seg2) in drop_hashes):
                return 'test_overlap_skips'
            if self.drop_dupes:
                if pair_hash in seen:
                    return 'dupes_skips'
                seen.add(pair_hash)
            return None

        with IO.writer(out_l1) as w1, IO.writer(out_l2) as w2, IO.writer(out_meta) as w3:
            with pbar_man.counter(color='green',
                                  total=len(paired_files),
                                  unit='it',
                                  desc="Merging",
                                  autorefresh=False) as pbar:
                for name, (src1, src2) in paired_files.items():
                    for seg1, seg2 in self.read_parallel(src1, src2):
                        counts['total'][name] += 1
                        reason = skip_reason(seg1, seg2)
                        if reason is not None:
                            counts[reason][name] += 1
                            continue
                        w1.write(seg1 + '\n')
                        w2.write(seg2 + '\n')
                        w3.write(name + '\n')
                        counts['selected'][name] += 1
                    pbar.update()

        stats = {'selected': sum(counts['selected'].values()),
                 'total': sum(counts['total'].values()),
                 'counts': counts}
        stats_msg = json.dumps(stats, indent=2)
        log.info('Train stats:\n' + stats_msg)
        IO.write_lines(self.dir / 'train.stats.json', stats_msg)
        return counts