# Assumed module-level context (not shown in this snippet): `import collections as coll`,
# `import json`, and project helpers log, IO, pbar_man, BCP47Tag, DEF_COMPRESS.

def add_train_entries(self, entries, merge_train=False):
    self.add_parts(self.train_parts_dir, entries, drop_noise=self.drop_train_noise)
    if not merge_train:
        return
    # merge all parts into a single train file per language
    l1, l2 = self.langs
    l1_files = list(self.train_parts_dir.glob(f"*.{l1}"))
    assert l1_files and len(l1_files) >= len(entries)
    l2_files = [l1_f.with_suffix(f".{l2}") for l1_f in l1_files]
    assert all(l2_f.exists() for l2_f in l2_files)
    log.info(f"Going to merge {len(l1_files)} files as one train file")
    counts = coll.defaultdict(int)
    of1 = self.dir / f'train.{l1}'
    of2 = self.dir / f'train.{l2}'
    of3 = self.dir / 'train.meta.gz'
    with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
        for if1, if2 in zip(l1_files, l2_files):
            # NOTE: str.rstrip() strips a *character set*, not a suffix, so
            # rstrip(f'.{l1}') would mangle names; slice the extension off instead
            name = if1.name[:-len(f'.{l1}')]
            for seg1, seg2 in self.read_parallel(if1, if2):
                w1.write(seg1 + '\n')
                w2.write(seg2 + '\n')
                w3.write(name + '\n')
                counts[name] += 1
    total = sum(counts.values())
    counts = {'total': total, 'parts': counts}
    counts_msg = json.dumps(counts, indent=2)
    log.info('Train stats:\n' + counts_msg)
    IO.write_lines(self.dir / 'train.stats.json', counts_msg)
    return counts
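# Both versions rely on a read_parallel helper that is not shown here. Below is
# a minimal, self-contained sketch of what such a helper could look like, using
# only the standard library; the project's actual implementation may differ
# (e.g., it may route through IO.reader and handle length mismatches differently).
import gzip
import itertools
from pathlib import Path
from typing import Iterator, Tuple

def read_parallel(path1: Path, path2: Path) -> Iterator[Tuple[str, str]]:
    """Yield aligned (src, tgt) segment pairs; raise if the files differ in length."""
    def opener(p: Path):
        # transparently handle gzip-compressed parts
        return gzip.open(p, 'rt', encoding='utf-8') if p.suffix == '.gz' \
            else p.open('r', encoding='utf-8')
    _sentinel = object()
    with opener(path1) as f1, opener(path2) as f2:
        for seg1, seg2 in itertools.zip_longest(f1, f2, fillvalue=_sentinel):
            if seg1 is _sentinel or seg2 is _sentinel:
                raise ValueError(f'{path1} and {path2} have unequal segment counts')
            yield seg1.rstrip('\n'), seg2.rstrip('\n')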
def add_train_entries(self, entries, merge_train=False, compress=False, drop_hashes=None):
    self.add_parts(self.train_parts_dir, entries, drop_noise=self.drop_train_noise,
                   compress=compress, desc='Training sets', fail_on_error=self.fail_on_error)
    if not merge_train:
        return
    lang1, lang2 = self.langs
    # paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2)
    paired_files = {}
    for ent in entries:
        e1, e2 = self.get_paths(self.train_parts_dir, ent)
        # if the entry's language order is swapped relative to self.langs, swap the paths
        _, swapped = BCP47Tag.check_compat_swap(self.langs, ent.did.langs, fail_on_incompat=True)
        if swapped:
            e1, e2 = e2, e1
        paired_files[str(ent.did)] = e1, e2
    log.info(f"Going to merge {len(paired_files)} files as one train file")
    compress_ext = f'.{DEF_COMPRESS}' if compress else ''
    l1_ext = f'{lang1}{compress_ext}'
    l2_ext = f'{lang2}{compress_ext}'
    of1 = self.dir / f'train.{l1_ext}'
    of2 = self.dir / f'train.{l2_ext}'
    of3 = self.dir / f'train.meta.{DEF_COMPRESS}'
    counts = dict(total=coll.defaultdict(int),
                  dupes_skips=coll.defaultdict(int),
                  test_overlap_skips=coll.defaultdict(int),
                  selected=coll.defaultdict(int))
    train_hashes = set()
    with IO.writer(of1) as w1, IO.writer(of2) as w2, IO.writer(of3) as w3:
        with pbar_man.counter(color='green', total=len(paired_files), unit='it',
                              desc="Merging", autorefresh=False) as pbar:
            for name, (if1, if2) in paired_files.items():
                for seg1, seg2 in self.read_parallel(if1, if2):
                    counts['total'][name] += 1
                    if self.drop_dupes or self.drop_tests:
                        hash_val = hash((seg1, seg2))
                        # skip pairs that overlap with held-out test sets,
                        # on either side or as a whole pair
                        if drop_hashes and (hash_val in drop_hashes
                                            or hash(seg1) in drop_hashes
                                            or hash(seg2) in drop_hashes):
                            counts['test_overlap_skips'][name] += 1
                            continue
                        # skip exact duplicate pairs already written
                        if self.drop_dupes:
                            if hash_val in train_hashes:
                                counts['dupes_skips'][name] += 1
                                continue
                            train_hashes.add(hash_val)
                    w1.write(seg1 + '\n')
                    w2.write(seg2 + '\n')
                    w3.write(name + '\n')
                    counts['selected'][name] += 1
                pbar.update()
    stats = dict(selected=sum(counts['selected'].values()),
                 total=sum(counts['total'].values()),
                 counts=counts)
    stats_msg = json.dumps(stats, indent=2)
    log.info('Train stats:\n' + stats_msg)
    IO.write_lines(self.dir / 'train.stats.json', stats_msg)
    return counts
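# The second version's drop_hashes parameter is expected to contain hashes of
# held-out (test) segments. A hypothetical helper to build such a set, matching
# the three membership checks in the merge loop above (pair hash plus each
# side's hash), could look like this. Note that Python's builtin hash() for
# strings is salted per process (PYTHONHASHSEED), so the set must be built in
# the same run that consumes it.
def build_drop_hashes(test_pairs):
    """test_pairs: iterable of (src, tgt) segment pairs from held-out sets."""
    drop = set()
    for seg1, seg2 in test_pairs:
        drop.add(hash((seg1, seg2)))  # exact pair overlap
        drop.add(hash(seg1))          # source-side overlap
        drop.add(hash(seg2))          # target-side overlap
    return drop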