def read_raw_bitext(spark, src_file: Union[str, Path], tgt_file: Union[str, Path],
                    src_name='src_raw', tgt_name='tgt_raw') -> Tuple[DataFrame, int]:
    """
    Reads bitext into a DataFrame
    :param spark: spark session
    :param src_file: source file to read from
    :param tgt_file: target file to read from
    :param src_name: name for source col in DF
    :param tgt_name: name for target col in DF
    :return: (DataFrame, record count)
    """
    if not isinstance(src_file, str):
        src_file = str(src_file)
    if not isinstance(tgt_file, str):
        tgt_file = str(tgt_file)

    src_df = spark.read.text(src_file).withColumnRenamed('value', src_name)
    tgt_df = spark.read.text(tgt_file).withColumnRenamed('value', tgt_name)
    n_src, n_tgt = src_df.count(), tgt_df.count()
    assert n_src == n_tgt, f'{n_src} == {n_tgt} ?'
    log.info(f"Found {n_src:,} parallel records in {src_file, tgt_file}")

    def with_idx(sdf):
        new_schema = StructType(sdf.schema.fields + [StructField("idx", LongType(), False)])
        return sdf.rdd.zipWithIndex().map(lambda row: row[0] + (row[1],)).toDF(schema=new_schema)

    src_df = with_idx(src_df)
    tgt_df = with_idx(tgt_df)
    bitext_df = src_df.join(tgt_df, 'idx', "inner")
    # n_bitext = bitext_df.count()
    # assert n_bitext == n_src, f'{n_bitext} == {n_src} ??'
    return bitext_df, n_src
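# Illustrative usage sketch for read_raw_bitext (not part of the original source).
# Assumes pyspark is installed; 'train.en' and 'train.de' are hypothetical, line-aligned
# bitext files.
def _example_read_raw_bitext():
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.master('local[2]').appName('bitext-demo').getOrCreate()
    try:
        bitext_df, n_recs = read_raw_bitext(spark, 'train.en', 'train.de',
                                            src_name='eng', tgt_name='deu')
        bitext_df.show(5)  # columns: idx, eng, deu
        print(f"{n_recs} parallel records")
    finally:
        spark.stop()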
def main():
    args = parse_args()
    task = args.pop('task')
    if task == 'learn':
        args.pop('out')      # No output
        args.pop('indices')  # No output
        assert args.get('level'), 'argument --level is required for "learn" task'
        import time
        from datetime import timedelta
        from nlcodec.utils import max_RSS
        st = time.time()
        st_mem = max_RSS()[1]
        learn_vocab(**args)
        delta = timedelta(seconds=time.time() - st)
        et_mem = max_RSS()[1]
        log.info(f"Time taken: {delta}; Memory: {st_mem} --> {et_mem}")
    elif task in ('encode', 'decode'):
        scheme = load_scheme(args.pop('model'))
        inp, out, indices = args['inp'], args['out'], args.get('indices', False)
        if task == 'encode':
            recs = encode(inp, scheme, indices=indices)
            if indices:
                recs = ([str(idx) for idx in seq] for seq in recs)
            recs = (' '.join(seq) for seq in recs)
        else:
            recs = decode(inp, scheme, indices=indices)
        write_lines(recs, out)
    elif task == 'estimate':
        from nlcodec.qestim import estimate
        estimate(codec_path=args['model'], data=args['inp'])
    else:
        raise NotImplementedError(task + ' not implemented')
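# Illustrative sketch of the 'encode' path that main() dispatches to, without going through
# argument parsing. File names are hypothetical; how load_scheme/encode/write_lines handle
# paths vs file objects is an assumption here, they are only shown as used in main() above.
def _example_encode_task():
    scheme = load_scheme('bpe.model')  # model previously produced by the 'learn' task
    with open('input.txt', encoding='utf8') as inp:
        recs = encode(inp, scheme, indices=False)   # generator over token sequences
        recs = (' '.join(seq) for seq in recs)
        write_lines(recs, 'output.txt')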
def __init__(self, table: List[Type]):
    super().__init__(table=table, invertible=False)
    self.root = self.make_vocab_prefix_trie(self.table)
    log.info(f"Vocab size={len(self)}; trie root has nodes={self.root.size}"
             f" but data_nodes={self.root.data_node_count}")
    assert self.unk_idx
def learn(cls, data: Iterator[str], vocab_size: int = 0, min_freq: int = WORD_MIN_FREQ,
          coverage: float = 0, term_freqs=False, **kwargs) -> List[Type]:
    """
    :param data: input sentences
    :param vocab_size: max vocabulary size
    :param min_freq: min frequency for inclusion in vocabulary; excludes types with lower freq
    :param coverage: character coverage
    :param term_freqs: is data the term_freqs?
    :param kwargs: placeholder for any extra args
    :return: List of Type
    """
    assert not kwargs, f'{kwargs} args are not allowed/understood'
    if term_freqs:  # input is term_freqs
        log.info("Restoring term frequencies from input")
        stats, line_count = cls.read_term_freqs(data=data)
    else:  # compute term freqs
        log.info("Computing term frequencies from raw data")
        stats, line_count = cls.term_frequencies(data=data)
    return cls.get_init_vocab(stats, coverage, line_count, min_freq, vocab_size)
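# Minimal sketch, assuming the learn() above belongs to the word-level scheme class
# (referenced elsewhere as WordScheme): build a small word vocabulary from in-memory sentences.
def _example_learn_word_vocab():
    sentences = ["the cat sat on the mat", "the dog sat on the log"]
    types = WordScheme.learn(sentences, vocab_size=20, min_freq=1)
    for t in types[:10]:
        print(t.idx, t.name, t.freq)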
def __init__(self, data_path: Path, batch_size: Union[int, Tuple[int, int]],
             batch_meta: BatchMeta, sort_desc: bool = False, batch_first: bool = True,
             sort_by: str = None, keep_in_mem=False):
    """
    Iterator for reading training data in batches
    :param data_path: path to TSV file
    :param batch_size: number of tokens on the target side per batch; or (max_toks, max_sents)
    :param sort_desc: should the mini batch be sorted by sequence len (useful for RNN api)
    :param keep_in_mem: keep all parts in memory for MultipartDb; for a single part,
        the part of course stays in memory
    """
    self.batch_meta = batch_meta
    self.sort_desc = sort_desc
    if isinstance(batch_size, int):
        self.max_toks, self.max_sents = batch_size, float('inf')
    else:
        self.max_toks, self.max_sents = batch_size
    self.batch_first = batch_first
    self.sort_by = sort_by
    assert sort_by in (None, 'eq_len_rand_batch', 'random')
    if not isinstance(data_path, Path):
        data_path = Path(data_path)
    self.data_path = data_path
    self.keep_in_mem = keep_in_mem
    assert data_path.exists(), f'Invalid State: {data_path} is NOT found.'
    if data_path.is_file():
        self.data = Db.load(data_path, rec_type=IdExample)
    elif data_path.is_dir():
        self.data = MultipartDb.load(data_path, rec_type=IdExample, keep_in_mem=keep_in_mem)
    else:
        raise Exception(f'Invalid State: {data_path} should be a file or dir.')
    log.info(f'Batch Size = {batch_size}, sort_by={sort_by}')
def learn(cls, data: Iterator[str], vocab_size: int = 0, min_freq=WORD_MIN_FREQ,
          coverage=CHAR_COVERAGE, min_co_evidence=MIN_CO_EV, term_freqs=False,
          **kwargs) -> List[Type]:
    assert vocab_size > 0
    assert not kwargs, f'{kwargs} args are not allowed/understood'
    if term_freqs:
        log.info("Reading term freqs from input")
        tfs, line_count = WordScheme.read_term_freqs(data)
    else:
        log.info("Computing term freqs from input")
        tfs, line_count = WordScheme.term_frequencies(data)

    def init_vocab_factory(char_types):
        return CharScheme.get_init_vocab(char_types, line_count=line_count,
                                         coverage=coverage, min_freq=1)

    from .bpe import BPELearn
    vocab = BPELearn.learn_subwords(term_freqs=tfs, vocab_size=vocab_size,
                                    init_vocab_factory=init_vocab_factory,
                                    min_co_evidence=min_co_evidence)
    return vocab
def read_vocab(cls, inp: Union[Path, str, TextIO]) -> Tuple[List['Type'], Optional[Dict]]:
    if isinstance(inp, (Path, str)):
        rdr = open(inp, 'r', encoding='utf8', errors='ignore')
    else:
        rdr = inp

    lines = list(l.strip() for l in rdr)
    meta = None
    if lines[0].startswith("#{"):
        # metadata such as version; not used as of now
        meta = json.loads(lines.pop(0)[1:])

    # noinspection PyTypeChecker
    vocab: List[Type] = [None] * len(lines)
    for i, line in enumerate(lines):
        v = Type.parse(line=line.rstrip('\n'), vocab=vocab)
        assert v.idx == i
        vocab[i] = v
    if rdr is not inp:
        rdr.close()
    log.info(f"read {len(vocab)} types from {rdr.name}")
    return vocab, meta
def make_n_grams_all(sents: Iterable[List[Any]], n):
    grams = coll.Counter()
    n_sent = 0
    for sent in tqdm(sents, mininterval=1, dynamic_ncols=True):
        grams.update(make_n_grams(sent, n))
        n_sent += 1
    log.info(f"Made {n}-grams: types={len(grams)}; tokens={sum(grams.values())}")
    return grams
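# Tiny sketch of make_n_grams_all over token-id sequences, assuming make_n_grams yields the
# contiguous n-grams of a single sequence (as implied by its use above).
def _example_bigram_counts():
    seqs = [[1, 2, 3, 2, 3], [2, 3, 4]]
    bigrams = make_n_grams_all(seqs, n=2)
    # under that assumption, the bigram (2, 3) occurs three times across both sequences
    print(bigrams.most_common(3))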
def term_frequencies(cls, data: Iterator[str]) -> Tuple[Dict[str, int], int]:
    stats = coll.Counter()
    line_count = 0
    for line in tqdm(data, mininterval=1):
        stats.update(cls.encode_str(line.strip()))
        line_count += 1
    log.info(f"Found {len(stats):,} types and {sum(stats.values()):,} tokens")
    return stats, line_count
def rdd_as_db(rdd, db_path: Path, field_names=('x', 'y'), repartition=0, **kwargs) -> MultipartDb:
    # input format: each rec in rdd should be: (id, (x, y))
    if repartition and repartition > 0:
        rdd = rdd.repartition(repartition)
    with MultipartDb.Writer(db_path, field_names=field_names, **kwargs) as writer:
        n = rdd.mapPartitionsWithIndex(writer).count()  # parallel write
    db = MultipartDb.load(db_path)
    log.info(f"Wrote {len(db)} recs in {len(db.part_paths)} parts at {db_path}")
    return db
def parallel_map(cls, mapper, collection, n_cpus=N_CPUS, name='', chunksize=1000):
    assert n_cpus > 1, f'at least 2 CPUs needed. chunksize={chunksize}'
    log.info(f"Going to use {n_cpus} parallel processes {name}")
    with mp.Pool(processes=n_cpus) as pool:
        yield from pool.imap(mapper, collection, chunksize=chunksize)
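# Usage sketch for parallel_map. The owning class is not shown above, so `Parallel` here is a
# hypothetical name standing in for it. The mapper must be picklable, i.e. a module-level
# function rather than a lambda, because multiprocessing pickles it for the worker processes.
def _square(x):
    return x * x

def _example_parallel_map():
    # results stream back lazily, in input order, in chunks of `chunksize`
    for y in Parallel.parallel_map(_square, range(10_000), n_cpus=2, name='squares'):
        pass  # consume the generator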
def __init__(self, path: Union[str, Path]):
    types, meta = Type.read_vocab(path)
    assert meta
    assert meta['max_level'] == Level.subword
    init_table = [t for t in types if t.level < Level.subword]
    self.merge_buffer = [t for t in types if t.level >= Level.subword]
    for m in self.merge_buffer:
        assert len(m.kids) == 2  # exactly two pieces
    super().__init__(table=init_table)  # start with characters
    self.root = self.make_vocab_prefix_trie(self.table)
    assert self.unk_idx
    log.info(f"initial table has {len(init_table)}; buffer has {len(self.merge_buffer)}")
def make_eq_len_ran_batches(self, max_toks, max_sents=float('inf')):
    batches = self._make_eq_len_batch_ids(max_toks=max_toks, max_sents=max_sents)
    if not batches:
        raise Exception('Found no data. Please check config data paths')
    log.info(f"Length-sorted random batches = {len(batches)}. Shuffling🔀...")
    # every pass introduces some randomness
    random.shuffle(batches)
    for batch_ids in batches:
        batch = [self[_id] for _id in batch_ids]
        yield batch
def learn_vocab(inp, level, model, vocab_size, min_freq=1, term_freqs=False,
                char_coverage=CHAR_COVERAGE, min_co_ev=MIN_CO_EV):
    if not min_freq or min_freq < 1:
        min_freq = WORD_MIN_FREQ if level == 'word' else CHAR_MIN_FREQ
        log.info(f"level={level} => default min_freq={min_freq}")
    else:
        log.info(f"level={level} => user given min_freq={min_freq}")
    log.info(f"Learn Vocab for level={level} and store at {model}")
    log.info(f"data ={inp}")
    Scheme = REGISTRY[level]
    args = {}
    if level != 'word':
        args['coverage'] = char_coverage  # no char_coverage for word
    if level == 'bpe':
        args['min_co_evidence'] = min_co_ev
    table = Scheme.learn(inp, vocab_size=vocab_size, min_freq=min_freq,
                         term_freqs=term_freqs, **args)
    Type.write_out(table=table, out=model)
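# Illustrative call of learn_vocab (file names are hypothetical). The input is expected to be
# an iterator over raw sentences, e.g. an open text file; the learned table is written to `model`.
def _example_learn_vocab():
    with open('train.txt', encoding='utf8') as inp:
        learn_vocab(inp=inp, level='bpe', model='bpe.8k.model', vocab_size=8000)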
@contextmanager
def log_resources(name=""):
    """
    Logs time and memory utilized by a code block
    :param name: some name to identify the code block
    :return:
    """
    st = time.time()
    st_mem = max_RSS()[1]
    try:
        yield name
    finally:
        delta = timedelta(seconds=time.time() - st)
        et_mem = max_RSS()[1]
        log.info(f"{name} Time: {delta}; Mem: {st_mem} --> {et_mem}")
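# Usage sketch for the log_resources context manager: wrap any expensive block to get its
# wall-clock time and max RSS logged when the block exits, even if it raises.
def _example_log_resources():
    with log_resources(name="toy workload"):
        total = sum(i * i for i in range(1_000_000))
    return total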
def filter_types_coverage(types: Dict[str, int], coverage=1.0) -> Tuple[Dict[str, int], int]:
    assert 0 < coverage <= 1
    tot = sum(types.values())
    includes = {}
    cum = 0
    types = sorted(types.items(), key=lambda x: x[1], reverse=True)
    for t, f in types:
        cum += f / tot
        includes[t] = f
        if cum >= coverage:
            break
    log.info(f'Coverage={cum:g}; requested={coverage:g}')
    excludes = {ch: ct for ch, ct in types if ch not in includes}
    unk_count = sum(excludes.values())
    log.warning(f'UNKed total toks:{unk_count} types={len(excludes)} from types:{excludes}')
    return includes, unk_count
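# Worked example for filter_types_coverage: with coverage=0.9, types are taken in descending
# frequency until their cumulative share of tokens reaches 90%; everything else becomes UNK.
def _example_filter_coverage():
    freqs = {'a': 50, 'b': 30, 'c': 15, 'd': 4, 'e': 1}  # 100 tokens total
    includes, unk_count = filter_types_coverage(freqs, coverage=0.9)
    # 'a' (cum 0.50) + 'b' (0.80) + 'c' (0.95 >= 0.9) are kept; 'd' and 'e' are UNKed
    assert set(includes) == {'a', 'b', 'c'} and unk_count == 5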
def get_init_vocab(cls, term_freqs, coverage: float = 0, line_count=None,
                   min_freq=WORD_MIN_FREQ, vocab_size=-1):
    vocab = Reseved.with_reserved_types()
    res_stats = {r_type.name: term_freqs.pop(r_type.name)
                 for r_type in vocab if r_type.name in term_freqs}
    if res_stats:
        log.warning(f"Found reserved types in corpus: {res_stats}")
    # Order of trimming techniques: 1. coverage, 2. min freq, 3. size cut off
    unk_count = 0
    if coverage:
        assert 0 < coverage <= 1
        term_freqs, coverage_unk_count = filter_types_coverage(term_freqs, coverage=coverage)
        unk_count += coverage_unk_count
    term_freqs = sorted(term_freqs.items(), key=lambda x: x[1], reverse=True)
    if min_freq and min_freq > 1:
        log.info(f"Excluding terms with freq < {min_freq}; |freq >= 1|: {len(term_freqs):,}")
        unk_count += sum(f for t, f in term_freqs if f < min_freq)
        term_freqs = [(t, f) for t, f in term_freqs if f >= min_freq]
        log.info(f"|freq >= {min_freq}| : {len(term_freqs):,}")

    if vocab_size > 0 and len(vocab) + len(term_freqs) > vocab_size:
        log.info(f"Truncating vocab at size={vocab_size}")
        unk_count += sum(f for t, f in term_freqs[vocab_size - len(vocab):])
        term_freqs = term_freqs[:vocab_size - len(vocab)]

    # update reserved types with corpus freqs
    for idx, t in enumerate(vocab):
        freq = 0
        if t.name in res_stats:
            freq = res_stats.pop(t.name)
        if idx == Reseved.UNK_IDX:
            freq += unk_count
        if idx in {Reseved.BOS_IDX, Reseved.EOS_IDX, Reseved.CLS_IDX} and line_count:
            freq += line_count
        if freq:
            log.warning(f"Update frequency for reserved type {t} with {freq}")
            vocab[idx] = t.copy(freq=freq)

    vocab += [Type(name=name, idx=idx, freq=freq, level=cls.level)
              for idx, (name, freq) in enumerate(term_freqs, start=len(vocab))]
    log.info(f"Total {cls} vocab size {len(vocab):,}")
    return vocab
def learn_subwords(cls, term_freqs: Dict[str, int], vocab_size: int,
                   min_co_evidence: int = DEF_MIN_CO_EV,
                   init_vocab_factory=None) -> List[Type]:
    """
    :param term_freqs: word types and their frequencies
    :param vocab_size: final vocab size: reserved + chars + user_specified + merges;
        special case: when `vocab_size=-1` the returned vocab will have just reserved + chars
    :param min_co_evidence: min co-occurrence evidence for pair merges
    :param init_vocab_factory: factory that builds the initial (character) vocabulary
        from character frequencies
    :return: List of Type
    """
    log.info(f"Total types: {len(term_freqs)}")
    term_freqs = {cls.prepare_word(word): freq for word, freq in term_freqs.items()}

    char_freqs = coll.defaultdict(int)
    for term, freq in term_freqs.items():
        for ch in term:
            char_freqs[ch] += freq
        """TODO: test this behavior; similar to subword-nmt v0.2
        for ch in term[:-2]:  # skip the last two: ending and the whitespace marker
            char_freqs[ch] += freq
        char_freqs[term[-2:]] += freq  # ending + whitespace marker go together as a single byte
        """
    init_vocab = init_vocab_factory(char_freqs)  # create an initial vocabulary of chars
    if vocab_size == -1:
        log.warning(f'Since vocab_size={vocab_size}; not going to do any L1 merges')
        log.info(f'Found initial vocab size of {len(init_vocab)}')
        return init_vocab

    return cls._learn_codes(term_freqs, init_vocab, min_co_evidence=min_co_evidence,
                            vocab_size=vocab_size)
def validate_index(self):
    """
    Call this any time to check that the index of uni, bi, and bi_ixs is valid.
    Raises an exception on an invalid index
    :return:
    """
    # Note: this check does not work with the fast-but-dirty heap updates
    max_code = max(self.uni)
    max_idx = max(t.idx for t in self.vocab)
    if not (max_code < self.vocab_size and max_code <= max_idx):
        raise ValueError(
            f'Vocab size is {self.vocab_size}, but max_code is {max_code}; max_idx={max_idx}')
    if not len(self.bi) == len(self.bi_ixs):
        raise ValueError(f"|bi|={len(self.bi)} and |bi_idxs|={len(self.bi_ixs)} are not same")
    for bigm, freq in self.bi.items():
        if not freq >= 0:
            raise ValueError(f"{bigm} has freq {freq}; expected non-negative value")
        if bigm not in self.bi_ixs:
            raise ValueError(f"{bigm} exists in bi but not in bi_ixs")
        idx_freq = sum(n.freq for n in self.bi_ixs[bigm])
        if not freq == idx_freq:
            raise ValueError(f"{bigm} has freq={freq} in bi but has {idx_freq} bi_ixs refs")
        # a bigram freq cannot exceed the unigram freqs of its parts
        if not freq <= self.uni[bigm[0]]:
            raise ValueError(f"{bigm} has freq={freq} in bi but {bigm[0]} has {self.uni[bigm[0]]}")
        if not freq <= self.uni[bigm[1]]:
            raise ValueError(f"{bigm} has freq={freq} in bi but {bigm[1]} has {self.uni[bigm[1]]}")
    for uni, freq in self.uni.items():
        if not freq >= 0:
            raise ValueError(f"{uni} has freq={freq}; expected non-negative value")
    log.info("Index is valid")
def create_index(self, data):
    """
    1. encode all: List[str] -> List[List[int]]
    2. count unigram freqs => uni
       2a) measure imbalance
       2b) measure seq lens
    3. index all bigrams: bi_ix (a, b) -> Node
       Oh, yeah, convert every List[int] to a doubly linked list.
       Create seq_lens: List[int]; each node stores the integer index of its sequence in the corpus
    4. ask bpe for the next merge; it gives (a, b) -> t
       a) count how many such bigrams are in the index; insert uni[t] = count,
          reduce uni[a] and uni[b] by that count
       b) loop through each indexed node: remove b, update a, update links :: x a b y => x t y;
          insert bi_ix[x, t] and bi_ix[t, y];
          using the seq number stored on the node, decrement its length by 1; assert len >= 1
       c) measure imbalance and avg length
    """
    log.info("Encoding and creating index")
    enc_data = self.bpe.encode_parallel(data)
    uni: Dict[int, int] = coll.defaultdict(int)
    bi_ixs: Dict[Bigram, Set[LnNode]] = coll.defaultdict(set)
    seq_lens = []
    with tqdm(enumerate(enc_data)) as data_bar:
        for idx, seq in data_bar:
            seq = LnNode.from_seq(seq, data=idx)
            seq_lens.append(len(seq))
            for node in seq:
                uni[node.val] += 1
            for node in seq[:-1]:
                bi_ixs[(node.val, node.right.val)].add(node)
            bar_msg = f'MaxRSS={max_RSS()[1]}'
            data_bar.set_postfix_str(bar_msg, refresh=False)
    self.seq_lens = np.array(seq_lens)
    log.info(f"Found {len(self.seq_lens)} sentences")
    for idx, freq in uni.items():
        self.uni[idx] = freq
    self.bi_ixs = bi_ixs
def _learn_codes(cls, term_freqs: Dict[str, int], vocab: List[Type], vocab_size: int,
                 init_list: List[str] = None,
                 min_co_evidence: int = DEF_MIN_CO_EV) -> List[Type]:
    """
    :param term_freqs: word types and frequencies
    :param vocab: initial vocab; usually reserved and alphabet
    :param vocab_size: desired vocabulary size
    :param init_list: any reserved words
    :return: List of Type (word pieces)
    """
    vocab = copy.copy(vocab)
    log.info(f"Found {len(term_freqs)} types before splitting; initial vocab {len(vocab)}")
    if init_list:
        log.info(f'Adding {len(init_list)} types to the initial vocab')
        assert not any(' ' in w for w in init_list), 'spaces not allowed in init_list words'
        vocab += [Type(cls.prepare_word(w), level=Level.user, idx=idx, freq=0)
                  for idx, w in enumerate(init_list, start=len(vocab))]

    rev_idx: Dict[str, int] = {word.name: word.idx for word in vocab}
    assert len(rev_idx) == len(vocab)  # one-to-one map
    assert vocab_size > len(vocab), f'vocab_size={vocab_size} is too small;' \
                                    f' found {len(vocab)} in the init vocab!' \
                                    f' Set a value larger than {len(vocab)}'

    seqs_freqs = cls._make_idxs(rev_idx, term_freqs)
    learner = BPELearn(seqs_freqs, vocab=vocab)
    final_vocab = learner.learn_codes(n_merges=vocab_size - len(vocab),
                                      min_co_evidence=min_co_evidence,
                                      code_level=Level.subword)
    return final_vocab
def write_out(cls, table: List['Type'], out: Union[Path, str, TextIO]):
    if isinstance(out, (Path, str)):
        wrtr = open(out, 'w', encoding='utf8', errors='ignore')
    else:
        wrtr = out

    levels = dict(coll.Counter(v.level for v in table))
    max_level = max(levels.keys())
    meta = dict(total=len(table), version=__version__, levels=levels,
                max_level=max_level, created=str(datetime.now()))
    meta = json.dumps(meta)
    wrtr.write(f"#{meta}\n")
    for i, item in enumerate(table):
        assert i == item.idx, f'{item} expected index {i}'
        wrtr.write(item.format() + '\n')
    if wrtr is not out:
        wrtr.close()
    log.info(f"Wrote {len(table)} to {wrtr.name}")
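# Round-trip sketch: write_out and read_vocab (shown earlier) are inverses over the text
# format, assuming both are exposed on Type and `table` is a previously learned vocabulary.
def _example_vocab_round_trip(table):
    Type.write_out(table=table, out='vocab.model')
    table2, meta = Type.read_vocab('vocab.model')
    assert len(table2) == len(table) and meta['total'] == len(table)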
def create_index(self, seqs):
    log.info("Going to build corpus stats index; this might take a lot of time and memory")
    n_seqs, n_ignored, n_replaced, bar_msg = 0, 0, 0, ''
    with tqdm(enumerate(seqs), unit='seqs', dynamic_ncols=True, mininterval=1) as data_bar:
        for idx, seq in data_bar:
            freq = 1  # default = 1 freq
            if isinstance(seq, tuple):  # if freq is available
                seq, freq = seq

            n_seqs += 1
            if not seq:
                log.warning(f"Skipping empty sequence at idx {idx + 1}")
                continue
            if idx == 0:  # basic validation on the first sequence
                assert isinstance(seq, list)  # first sequence should be tokenized
                assert isinstance(seq[0], int)  # sequence items should be ints or codepoints

            nodes = LnNode.from_seq(seq, freq=freq)
            assert len(seq) == len(nodes)

            for i in range(len(seq) - 1):  # the last position is left out
                bigm = (seq[i], seq[i + 1])
                self.bi[bigm] += freq
                assert nodes[i] not in self.bi_ixs[bigm]
                self.bi_ixs[bigm].add(nodes[i])  # bigm found at node i
                self.uni[seq[i]] += freq
            self.uni[seq[-1]] += freq  # the last unigram count; not covered in the above loop
            bar_msg = f'MaxRSS={max_RSS()[1]}'
            data_bar.set_postfix_str(bar_msg, refresh=False)
    log.info(f"Created index; {bar_msg}")
def learn_codes(self, n_merges: int, min_co_evidence, code_level: int) -> List[Type]:
    """
    :param n_merges: how many more merges
    :param min_co_evidence: min evidence (co-occurrence frequency);
        causes early stop when no pair meets it
    :param code_level: what level to use for new code types created during merge;
        for instance level=1 for word bpe, level=2 for seq bpe
    :return: the updated vocabulary
    """
    uni, bi_ixs = self.uni, self.bi_ixs
    heap = MaxHeap(self.bi)
    heap_dirty = coll.defaultdict(int)  # subtractions aren't updated in max-heap, they are here
    vocab = self.vocab
    for i in range(n_merges):
        # Using MaxHeap for faster lookup of max. But the heap gets a bit dirty, so clean up
        max_pair, pair_freq = heap.pop()
        while max_pair in heap_dirty:  # clean all max pairs until a clean value
            freq_update = heap_dirty.pop(max_pair)
            assert freq_update < 0  # only decrements are valid; increments make this wrong
            corr_freq = pair_freq + freq_update  # corrected value
            assert corr_freq >= 0, f'{max_pair}:{pair_freq}, Δ={freq_update} = {corr_freq}'
            if corr_freq > 0:  # exclude zero count
                heap.push(max_pair, corr_freq)
            max_pair, pair_freq = heap.pop()

        # here the actual loop begins
        if pair_freq < min_co_evidence:
            log.warning(f"Early stop; max evidence found is {pair_freq} "
                        f"but min required is {min_co_evidence}")
            break

        new_type_idx = len(vocab)
        a, b = max_pair
        log.info(f"{(100 * i / n_merges):.2f}% :: {new_type_idx} || {a:4}:{uni[a]:5}"
                 f" || {b:4}:{uni[b]:5} || {pair_freq:,} || {vocab[a].name} {vocab[b].name}")

        # code -> bigram (flatten out bigram; resolve interim codes)
        new_type = Type(vocab[a].name + vocab[b].name, idx=new_type_idx, freq=pair_freq,
                        level=code_level, kids=(vocab[a], vocab[b]))
        vocab.append(new_type)

        # updates: update bigram and unigram counts
        uni[new_type_idx] = pair_freq  # this bigram is now a new unigram
        # unigram counts drop, since some of their bigrams are removed
        uni[a] -= pair_freq
        uni[b] -= pair_freq

        heap_deltas = coll.defaultdict(int)
        update_nodes = bi_ixs.pop(max_pair)  # also removed from bi_ixs
        for node in update_nodes:
            # -- x a b y --
            x_node, b_node = node.left, node.right
            if node.is_unlinked or (a == b and new_type.idx in (node.val, b_node.val)):
                # this happens in cases like "x a a a a y"; revert the counts for this node
                uni[a] += node.freq
                uni[b] += node.freq
                uni[new_type.idx] -= node.freq
                continue

            y_node = b_node.right
            dirty = node.val != a or b_node.val != b  # check that the linked list is proper
            if dirty:
                log.warning(f'Expected {a, b} but found {node.val, b_node.val}'
                            f'\n {node, b_node}'
                            f'\n--{vocab[a].signature()} =='
                            f' {vocab[node.val].signature() if node.val != a else "OK"}'
                            f'\n--{vocab[b].signature()} =='
                            f' {vocab[b_node.val].signature() if b_node.val != b else "OK"}')
                log.warning(f"a={a}, b={b} || a_node={node}, b_node={b_node}")
            assert not dirty
            assert node.freq == b_node.freq

            # update: x a b y => x R y
            b_node.delete(unlink=True)  # delete() takes care of linking a → y and a ← y
            new_node = node  # reuse a node as new_node/R
            new_node.val = new_type_idx  # reuse a as new_node/R
            # Note: the above edits to a and b nodes do-not/should-not change __hash__

            if x_node:
                # remove (x_node.val, a) from bi and bi_ixs
                heap_deltas[(x_node.val, a)] -= x_node.freq
                if bi_ixs.get((x_node.val, a)):  # not sure why 'if' is needed here
                    bi_ixs[(x_node.val, a)].remove(x_node)
                # add (x_node.val, R) to bi and bi_ixs
                heap_deltas[(x_node.val, new_type_idx)] += x_node.freq
                bi_ixs[(x_node.val, new_type_idx)].add(x_node)
            if y_node:
                # remove (b, y_node.val) from bi and bi_ixs
                heap_deltas[(b, y_node.val)] -= b_node.freq
                if bi_ixs.get((b, y_node.val)):  # not sure why 'if' is needed here
                    bi_ixs[(b, y_node.val)].remove(b_node)
                # add (R, y_node.val) to bi and bi_ixs
                heap_deltas[(new_type_idx, y_node.val)] += b_node.freq
                bi_ixs[(new_type_idx, y_node.val)].add(new_node)

        # however, the counts shouldn't go negative
        assert uni[a] >= 0
        assert uni[b] >= 0

        for pair, delta in heap_deltas.items():
            if delta > 0:  # these are new insertions, and they can go directly to the heap
                assert new_type_idx in pair
                heap.push(pair, delta)
            elif delta < 0:  # one of those subtractions, which can't be directly updated
                assert new_type_idx not in pair
                heap_dirty[pair] += delta
    return vocab
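# Worked example of a single merge step in learn_codes (counts are illustrative, not from a
# real corpus): suppose the id sequences contain "x a b y" with freq 3 and "z a b" with freq 2,
# and the max pair popped from the heap is (a, b) with pair_freq = 5.
#   - a new type R = name(a)+name(b) is appended to the vocab; uni[R] = 5, uni[a] -= 5, uni[b] -= 5
#   - every indexed node for (a, b) is rewritten in place: x a b y  =>  x R y
#   - heap_deltas collects -freq for the broken bigrams (x, a) and (b, y), and +freq for the new
#     bigrams (x, R) and (R, y); positive deltas go straight onto the heap, while negative ones
#     are parked in heap_dirty and applied lazily the next time that pair is popped.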
def shrink_vocab(self, files: List[Path], min_freq: int, save_at: Optional[Path] = None) -> List[int]:
    """
    Shrinks the vocabulary to the types observed in the given files:
    - accept a list of files and compute term frequencies
    - eliminate types whose frequency is below min_freq
    - preserve reserved types even if they have zero counts
    - preserve all the sub pieces leading to the retained pieces
    - save the resulting model at the given path (if save_at is given)
    - return the index mapping between old and new, so we can go back to the model
      and shrink its embedding tables
    :param files: corpus files to compute term frequencies from
    :param min_freq: minimum frequency for a type to be retained
    :param save_at: optional path to save the shrunk vocabulary
    :return: mapping as a list where position = new idx and value = old idx
    """
    from tqdm import tqdm
    freqs = coll.Counter()
    for file in files:
        log.info(f'Computing term frequencies from {file}')
        with IO.reader(file) as lines:
            for line in tqdm(lines):
                freqs.update(self.encode(line))
    assert len(self.table) > max(freqs.keys())
    removals = [False] * len(self.table)
    for idx, typ in enumerate(self.table):
        if typ.level == Level.reserved:
            continue  # i.e. don't remove
        removals[idx] = freqs[idx] < min_freq  # remove if min_freq threshold not met

    # now make sure to preserve all the sub pieces leading to the pieces that are retained
    for idx in range(len(self.table) - 1, -1, -1):
        combo = self.table[idx]
        assert combo.idx == idx
        if not removals[combo.idx] and combo.kids:
            for piece in combo.kids:
                if removals[piece.idx]:
                    removals[piece.idx] = False  # don't remove this piece

    mapping = []
    for idx, (is_remove, typ) in enumerate(zip(removals, self.table)):
        assert idx == typ.idx
        if is_remove:
            continue
        mapping.append(idx)

    log.info(f"Shrinking vocab tables: {len(self.table)} --> {len(mapping)}")
    rev_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(mapping)}
    old_table = self.table
    new_table = []
    for new_idx, old_idx in enumerate(mapping):
        assert len(new_table) == new_idx
        old_t = old_table[old_idx]
        new_kids = [new_table[rev_mapping[k.idx]] for k in old_t.kids] if old_t.kids else None
        new_t = old_t.copy(idx=new_idx, kids=new_kids)
        new_table.append(new_t)
    if save_at:
        Type.write_out(new_table, out=save_at)
    return mapping
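# Usage sketch for shrink_vocab (file names are hypothetical): shrink a learned codec to the
# types actually seen in a smaller corpus, then remap embedding rows with the returned mapping.
# `codec` stands for an object exposing the shrink_vocab method above; `emb_matrix` is assumed
# to be a numpy/torch array with one row per old vocab index.
def _example_shrink_vocab(codec, emb_matrix):
    from pathlib import Path
    mapping = codec.shrink_vocab(files=[Path('small.train.txt')], min_freq=1,
                                 save_at=Path('small.model'))
    new_emb = emb_matrix[mapping]  # fancy indexing keeps only the retained rows, in new order
    return new_emb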
def track_merges(self, do_log=False):
    uni = self.uni
    bi_ixs = self.bi_ixs
    seq_lens = self.seq_lens
    res = []
    while self.bpe.merges_remaining() > 0:
        """
        ask bpe for the next merge; it gives (a, b) -> t
        a) count how many such bigrams are in the index; insert uni[t] = count,
           reduce uni[a] and uni[b] by that count
        b) loop through each indexed node: remove b, update a, update links :: x a b y => x t y;
           insert bi_ix[x, t] and bi_ix[t, y];
           using the seq number stored on the node, decrement its length by 1; assert len >= 1
        c) measure imbalance and avg length
        """
        k = len(self.bpe)
        emd = self.earth_mov_dist(self.uni[:k])
        mu_len = self.mean_seq_len()
        res.append((k, emd, mu_len))
        if do_log:
            log.info(f"{k}: imbalance={emd:g} mean_seq_len={mu_len:g}."
                     f" Next: {self.bpe.peek_merge().signature()}")
        new_type = self.bpe.merge()
        a_type, b_type = new_type.kids
        a, b = a_type.idx, b_type.idx
        update_nodes = bi_ixs.pop((a, b), None)
        if update_nodes:
            freq = sum(u.freq for u in update_nodes)
            uni[new_type.idx] = freq
            uni[a] -= freq
            uni[b] -= freq
            for u in update_nodes:
                b_node = u.right
                if u.is_unlinked or (a == b and new_type.idx in (u.val, b_node.val)):
                    # happens due to repeats like X A A A A Y; revert the counts for this node
                    uni[new_type.idx] -= u.freq
                    uni[a] += u.freq
                    uni[b] += u.freq
                    continue
                assert u.val == a
                assert b_node.val == b, f'expected {a, b} found {u.val, b_node.val} at {b_node}'
                u.val = new_type.idx
                b_node.delete(unlink=True)
                seq_idx = u.data
                seq_lens[seq_idx] -= 1
                assert seq_lens[seq_idx] > 0
                if u.left:
                    if bi_ixs.get((u.left.val, a)):
                        bi_ixs[(u.left.val, a)].remove(u.left)
                    bi_ixs[(u.left.val, u.val)].add(u.left)
                if u.right:
                    if bi_ixs.get((b, u.right.val)):
                        bi_ixs[(b, u.right.val)].remove(b_node)
                    bi_ixs[(u.val, u.right.val)].add(u)

    k = len(self.bpe)
    emd = self.earth_mov_dist(self.uni[:k])
    mu_len = self.mean_seq_len()
    res.append((k, emd, mu_len))
    log.info(f"{k}: imbalance={emd:g} mean_seq_len={mu_len:g}.")
    return res