def read_all(self):
    batch = []
    max_len = 0
    for ex in self.data:
        if min(len(ex.x), len(ex.y)) == 0:
            log.warning("Skipping a record, either source or target is empty")
            continue

        this_len = max(len(ex.x), len(ex.y))
        if (len(batch) + 1) * max(max_len, this_len) <= self.max_toks and len(batch) < self.max_sents:
            batch.append(ex)  # this one can go in
            max_len = max(max_len, this_len)
        else:
            if this_len > self.max_toks:
                raise Exception(f'Unable to make a batch of {self.max_toks} toks'
                                f' with a seq of x_len:{len(ex.x)} y_len:{len(ex.y)}')
            # yield the current batch
            yield Batch(batch, sort_dec=self.sort_desc, batch_first=self.batch_first,
                        meta=self.batch_meta)
            batch = [ex]  # new batch
            max_len = this_len
    if batch:
        log.debug(f"\nLast batch, size={len(batch)}")
        yield Batch(batch, sort_dec=self.sort_desc, batch_first=self.batch_first,
                    meta=self.batch_meta)
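# --- Illustrative sketch; not part of the library -----------------------------
# read_all() above caps the *padded* size of each batch: admitting one more
# example costs (len(batch) + 1) * max(max_len, this_len) tokens, since every
# sequence gets padded to the longest one in the batch. A minimal standalone
# rendering of that rule (hypothetical helper; drops Batch construction and the
# too-long-sequence error), assuming examples are plain (x_len, y_len) tuples:
def batch_by_padded_tokens(lengths, max_toks=100, max_sents=8):
    batch, max_len = [], 0
    for x_len, y_len in lengths:
        this_len = max(x_len, y_len)
        if (len(batch) + 1) * max(max_len, this_len) <= max_toks and len(batch) < max_sents:
            batch.append((x_len, y_len))  # fits within the token and sentence budget
            max_len = max(max_len, this_len)
        else:
            yield batch
            batch, max_len = [(x_len, y_len)], this_len
    if batch:
        yield batch  # the last, possibly smaller, batch

# e.g. list(batch_by_padded_tokens([(5, 7), (6, 6), (40, 35), (90, 10)], max_toks=100))
# -> [[(5, 7), (6, 6)], [(40, 35)], [(90, 10)]]
# -------------------------------------------------------------------------------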
def get_init_vocab(cls, term_freqs, coverage: float = 0, line_count=None,
                   min_freq=WORD_MIN_FREQ, vocab_size=-1):
    vocab = Reseved.with_reserved_types()
    res_stats = {r_type.name: term_freqs.pop(r_type.name) for r_type in vocab
                 if r_type.name in term_freqs}
    if res_stats:
        log.warning(f"Found reserved types in corpus: {res_stats}")

    # Order of trimming techs: 1. coverage, 2. min freqs, 3. size cut off
    unk_count = 0
    if coverage:
        assert 0 < coverage <= 1
        term_freqs, coverage_unk_count = filter_types_coverage(term_freqs, coverage=coverage)
        unk_count += coverage_unk_count
    term_freqs = sorted(term_freqs.items(), key=lambda x: x[1], reverse=True)
    if min_freq and min_freq > 1:
        log.info(f"Excluding terms with freq < {min_freq}; |freq >= 1|: {len(term_freqs):,}")
        unk_count += sum(f for t, f in term_freqs if f < min_freq)
        term_freqs = [(t, f) for t, f in term_freqs if f >= min_freq]
        log.info(f"|freq >= {min_freq}| : {len(term_freqs):,}")

    if vocab_size > 0 and len(vocab) + len(term_freqs) > vocab_size:
        log.info(f"Truncating vocab at size={vocab_size}")
        unk_count += sum(f for t, f in term_freqs[vocab_size - len(vocab):])
        term_freqs = term_freqs[:vocab_size - len(vocab)]

    # update reserved types with corpus freqs
    for idx, t in enumerate(vocab):
        freq = 0
        if t.name in res_stats:
            freq = res_stats.pop(t.name)
        if idx == Reseved.UNK_IDX:
            freq += unk_count
        if idx in {Reseved.BOS_IDX, Reseved.EOS_IDX, Reseved.CLS_IDX} and line_count:
            freq += line_count
        if freq:
            log.warning(f"Update frequency for reserved type {t} with {freq}")
            vocab[idx] = t.copy(freq=freq)

    vocab += [Type(name=name, idx=idx, freq=freq, level=cls.level)
              for idx, (name, freq) in enumerate(term_freqs, start=len(vocab))]
    log.info(f"Total {cls} vocab size {len(vocab):,}")
    return vocab
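# --- Illustrative sketch; not part of the library -----------------------------
# get_init_vocab() trims in three stages: (1) coverage cut (see
# filter_types_coverage below), (2) min-frequency cut, (3) hard size cut; every
# frequency dropped along the way is folded into the UNK type's count. A toy
# standalone version of stages (2) and (3) (hypothetical helper), assuming a
# plain {term: freq} dict and `n_reserved` slots already taken by reserved types:
def trim_vocab(term_freqs, min_freq=2, vocab_size=6, n_reserved=4):
    unk_count = 0
    items = sorted(term_freqs.items(), key=lambda x: x[1], reverse=True)
    unk_count += sum(f for _, f in items if f < min_freq)   # stage 2: min-frequency cut
    items = [(t, f) for t, f in items if f >= min_freq]
    room = vocab_size - n_reserved                          # stage 3: hard size cut
    if len(items) > room:
        unk_count += sum(f for _, f in items[room:])
        items = items[:room]
    return items, unk_count

# e.g. trim_vocab({'the': 50, 'cat': 10, 'sat': 3, 'zzz': 1}, min_freq=2, vocab_size=6)
# -> ([('the', 50), ('cat', 10)], 4)   # 'zzz' cut by min_freq; 'sat' cut by size
# -------------------------------------------------------------------------------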
def filter_types_coverage(types: Dict[str, int], coverage=1.0) -> Tuple[Dict[str, int], int]:
    assert 0 < coverage <= 1
    tot = sum(types.values())
    includes = {}
    cum = 0
    types = sorted(types.items(), key=lambda x: x[1], reverse=True)
    for t, f in types:
        cum += f / tot
        includes[t] = f
        if cum >= coverage:
            break
    log.info(f'Coverage={cum:g}; requested={coverage:g}')
    excludes = {ch: ct for ch, ct in types if ch not in includes}
    unk_count = sum(excludes.values())
    log.warning(f'UNKed total toks:{unk_count} types={len(excludes)} from types:{excludes}')
    return includes, unk_count
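# --- Illustrative usage; not part of the library -------------------------------
# Example of the coverage cut above: with frequencies a:3, b:1 (total 4) and
# coverage=0.75, the cumulative share reaches 0.75 at 'a', so 'b' is dropped and
# its count is reported as the UNKed total:
#
#   includes, unk_count = filter_types_coverage({'a': 3, 'b': 1}, coverage=0.75)
#   # includes == {'a': 3}, unk_count == 1
# -------------------------------------------------------------------------------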
def __init__(self, path, field_names: List[str], overwrite=False, max_parts=DEF_MAX_PARTS):
    self.field_names = field_names
    path = as_path(path)
    if path.exists() and len(os.listdir(path)) > 0:
        if overwrite:
            log.warning(f"Removing existing data at {path}")
            shutil.rmtree(path)
        else:
            raise Exception(f'{path} already exists. not overwriting it')
    path.mkdir(parents=True, exist_ok=True)
    self.path = path
    self.part_path_pad = part_path_pads(max_parts)
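# --- Illustrative sketch; not part of the library -----------------------------
# The constructor above follows a "fail unless overwrite" convention for output
# directories. A minimal standalone rendering of the same pattern (hypothetical
# helper, stdlib only):
import shutil
from pathlib import Path

def prepare_dir(path, overwrite=False):
    path = Path(path)
    if path.exists() and any(path.iterdir()):
        if not overwrite:
            raise FileExistsError(f'{path} already exists; pass overwrite=True to replace it')
        shutil.rmtree(path)  # wipe the old contents before reuse
    path.mkdir(parents=True, exist_ok=True)
    return path
# -------------------------------------------------------------------------------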
def learn_subwords(cls, term_freqs: Dict[str, int], vocab_size: int,
                   min_co_evidence: int = DEF_MIN_CO_EV,
                   init_vocab_factory=None) -> List[Type]:
    """
    :param term_freqs: word -> frequency mapping from the corpus
    :param vocab_size: final vocab size: reserved + chars + user_specified + merges;
        special case, when `vocab_size=-1` the returned vocab will have just reserved + chars
    :param min_co_evidence: min co-evidence for pair merges
    :param init_vocab_factory: factory that creates the initial (character) vocabulary
        from character frequencies
    :param char_coverage: percentage of characters to be covered by initial char_freqs
    :param word_min_freq: words below this frequency will be excluded for learning BPE
    :return: List of Type
    """
    log.info(f"Total types: {len(term_freqs)}")
    term_freqs = {cls.prepare_word(word): freq for word, freq in term_freqs.items()}

    char_freqs = coll.defaultdict(int)
    for term, freq in term_freqs.items():
        for ch in term:
            char_freqs[ch] += freq
        """TODO: test this behavior; similar to subword-nmt v0.2
        for ch in term[:-2]:  # skip the last two: ending and the whitespace marker
            char_freqs[ch] += freq
        char_freqs[term[-2:]] += freq  # ending + whitespace marker go together as a single byte
        """
    init_vocab = init_vocab_factory(char_freqs)  # create an initial vocabulary of chars
    if vocab_size == -1:
        log.warning(f'Since vocab_size={vocab_size}; not going to do any L1 merges')
        log.info(f'Found initial vocab size of {len(init_vocab)}')
        return init_vocab

    return cls._learn_codes(term_freqs, init_vocab, min_co_evidence=min_co_evidence,
                            vocab_size=vocab_size)
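# --- Illustrative sketch; not part of the library -----------------------------
# The first stage of learn_subwords() simply expands word frequencies into
# character frequencies, which seed the initial vocabulary before any merges.
# A standalone rendering (hypothetical helper); '▁' below is just a stand-in for
# whatever end-of-word marker prepare_word() attaches:
from collections import defaultdict

def char_frequencies(term_freqs):
    char_freqs = defaultdict(int)
    for term, freq in term_freqs.items():
        for ch in term:
            char_freqs[ch] += freq  # each char inherits the word's frequency
    return dict(char_freqs)

# e.g. char_frequencies({'low▁': 5, 'lower▁': 2})
# -> {'l': 7, 'o': 7, 'w': 7, '▁': 7, 'e': 2, 'r': 2}
# -------------------------------------------------------------------------------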
def _make_eq_len_batch_ids(self, max_toks, max_sents, min_len=1):
    fields = list(self.fields.values())
    rows = []
    skip_counter = defaultdict(int)
    for _id in self.ids:
        lens = tuple(field.get_len(_id) for field in fields)
        if min(lens) < min_len:
            skip_counter[f'len < {min_len}'] += 1
        else:
            rows.append((_id, max(lens)))
    if len(skip_counter) > 0:
        log.warning(f"Skipped :: {skip_counter}")

    rows = np.array(rows)  # id, len
    np.random.shuffle(rows)  # in-place, along the first axis; for extra rand within len group
    rows = rows[rows[:, 1].argsort()]  # sort by the second column, i.e. len

    batches = []
    batch = []
    max_len = 0
    for _id, _len in rows:
        if _len < 1:
            log.warning(f"Skipping record {_id}, either source or target is empty")
            continue
        if (len(batch) + 1) * max(max_len, _len) > max_toks or len(batch) > max_sents:
            if _len > max_toks:
                raise Exception(f'Unable to make a batch of {max_toks} toks'
                                f' with a seq of len:{_len}')
            batches.append(np.array(batch))
            batch = []  # new batch
            max_len = 0
        batch.append(_id)  # this one can go in
        max_len = max(max_len, _len)
    if batch:
        batches.append(np.array(batch))
    return batches
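# --- Illustrative sketch; not part of the library -----------------------------
# The shuffle-then-sort step above groups ids of similar length together while
# keeping the order within each length group random, so the equal-length batches
# differ from epoch to epoch. A standalone rendering with plain numpy
# (hypothetical helper):
import numpy as np

def ids_sorted_by_len(id_len_pairs):
    rows = np.array(id_len_pairs)       # columns: id, length
    np.random.shuffle(rows)             # randomize order among equal lengths
    return rows[rows[:, 1].argsort()]   # then sort by the length column

# e.g. ids_sorted_by_len([(0, 7), (1, 3), (2, 7), (3, 3)])
# lengths come out as 3, 3, 7, 7; which id leads within each group varies per run
# -------------------------------------------------------------------------------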
def create_index(self, seqs):
    log.info("Going to build corpus stats index; this might take a lot of time and memory")
    n_seqs, n_ignored, n_replaced, bar_msg = 0, 0, 0, ''
    with tqdm(enumerate(seqs), unit='seqs', dynamic_ncols=True, mininterval=1) as data_bar:
        for idx, seq in data_bar:
            freq = 1  # default = 1 freq
            if isinstance(seq, tuple):  # if freq is available
                seq, freq = seq

            n_seqs += 1
            if idx == 0:  # basic validation
                assert isinstance(seq, list)  # first sequence, tokenized
                assert isinstance(seq[0], int)  # sequence's item should be an int or codepoint
            if not seq:
                log.warning(f"Skipping empty sequence at idx {idx + 1}")
                continue

            nodes = LnNode.from_seq(seq, freq=freq)
            assert len(seq) == len(nodes)

            for i in range(len(seq) - 1):  # the last position left out
                bigm = (seq[i], seq[i + 1])
                self.bi[bigm] += freq
                assert nodes[i] not in self.bi_ixs[bigm]
                self.bi_ixs[bigm].add(nodes[i])  # bigm found at node i
                self.uni[seq[i]] += freq
            self.uni[seq[-1]] += freq  # the last unigram count; not covered in the above loop

            bar_msg = f'MaxRSS={max_RSS()[1]}'
            data_bar.set_postfix_str(bar_msg, refresh=False)
    log.info(f"Created index; {bar_msg}")
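# --- Illustrative sketch; not part of the library -----------------------------
# At its core, the index built above holds weighted unigram and bigram counts
# over id sequences, plus a reverse map (bi_ixs) from each bigram to the
# linked-list nodes where it occurs; this sketch (hypothetical helper) covers
# only the counting part:
from collections import Counter

def count_ngrams(seqs):
    uni, bi = Counter(), Counter()
    for seq, freq in seqs:                 # assuming (sequence, frequency) pairs
        for i in range(len(seq) - 1):
            uni[seq[i]] += freq
            bi[(seq[i], seq[i + 1])] += freq
        if seq:
            uni[seq[-1]] += freq           # last unigram, not covered by the loop
    return uni, bi

# e.g. count_ngrams([([1, 2, 2, 3], 4)])
# -> uni == {2: 8, 1: 4, 3: 4};  bi == {(1, 2): 4, (2, 2): 4, (2, 3): 4}
# -------------------------------------------------------------------------------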
def learn_codes(self, n_merges: int, min_co_evidence, code_level: int) -> List[Type]:
    """
    :param n_merges: how many more merges
    :param min_co_evidence: min evidence (co-occurrence frequency);
        causes early stop upon failure
    :param code_level: what level to use for new code types created during merge
        for instance level=1 for word bpe; level=2 for seq bpe
    :return:
    """
    uni, bi_ixs = self.uni, self.bi_ixs
    heap = MaxHeap(self.bi)
    heap_dirty = coll.defaultdict(int)  # subtractions aren't updated in max-heap, they are here
    vocab = self.vocab
    for i in range(n_merges):
        # Using MaxHeap for faster lookup of max. But heap gets a bit dirty, so a bit of cleanup
        max_pair, pair_freq = heap.pop()
        while max_pair in heap_dirty:  # clean all max pairs until a clean value
            freq_update = heap_dirty.pop(max_pair)
            assert freq_update < 0  # only decrements are valid; increments make this wrong
            corr_freq = pair_freq + freq_update  # correct value
            assert corr_freq >= 0, f'{max_pair}:{pair_freq}, Δ={freq_update} = {corr_freq}'
            if corr_freq > 0:  # exclude zero count
                heap.push(max_pair, corr_freq)
            max_pair, pair_freq = heap.pop()

        # here the actual loop begins
        if pair_freq < min_co_evidence:
            log.warning(f"Early stop; max evidence found is {pair_freq} "
                        f"but min required is {min_co_evidence}")
            break

        new_type_idx = len(vocab)
        a, b = max_pair
        log.info(f"{(100 * i / n_merges):.2f}% :: {new_type_idx} || {a:4}:{uni[a]:5}"
                 f" || {b:4}:{uni[b]:5} || {pair_freq:,} || {vocab[a].name} {vocab[b].name}")

        # code -> bigram (flatten out bigram; resolve interim codes)
        new_type = Type(vocab[a].name + vocab[b].name, idx=new_type_idx, freq=pair_freq,
                        level=code_level, kids=(vocab[a], vocab[b]))
        vocab.append(new_type)

        # updates: update bigram and unigram counts
        uni[new_type_idx] = pair_freq  # this bigram is now a new unigram
        # unigram counts drop, since some of their bigrams are removed
        uni[a] -= pair_freq
        uni[b] -= pair_freq

        heap_deltas = coll.defaultdict(int)
        update_nodes = bi_ixs.pop(max_pair)  # also removed from bi_ixs
        for node in update_nodes:
            # -- x a b y --
            x_node, b_node = node.left, node.right
            if node.is_unlinked or (a == b and new_type.idx in (node.val, b_node.val)):
                # this happens in cases like "x a a a a y"
                uni[a] += node.freq
                uni[b] += node.freq
                uni[new_type.idx] -= node.freq
                continue

            y_node = b_node.right
            dirty = node.val != a or b_node.val != b  # check that the linked list is proper
            if dirty:
                log.warning(f'Expected {a, b} but found {node.val, b_node.val}'
                            f'\n {node, b_node}'
                            f'\n--{vocab[a].signature()} =='
                            f' {vocab[node.val].signature() if node.val != a else "OK"}'
                            f'\n--{vocab[b].signature()} =='
                            f' {vocab[b_node.val].signature() if b_node.val != b else "OK"}')
                log.warning(f"a={a}, b={b} || a_node={node}, b_node={b_node}")
            assert not dirty
            assert node.freq == b_node.freq

            # update: x a b y => x R y
            b_node.delete(unlink=True)  # delete() takes care of linking a → y and a ← y
            new_node = node  # reuse a node as new_node/R
            new_node.val = new_type_idx  # reuse a as new_node/R
            # Note: the above edits to a and b nodes do-not/should-not change __hash__

            if x_node:
                # remove (x_node.val, a) from bi and bi_ixs
                heap_deltas[(x_node.val, a)] -= x_node.freq
                if bi_ixs.get((x_node.val, a)):  # not sure why 'if' needed here;
                    bi_ixs[(x_node.val, a)].remove(x_node)

                # add (x_node.val, R) to bi and bi_ixs
                heap_deltas[(x_node.val, new_type_idx)] += x_node.freq
                bi_ixs[(x_node.val, new_type_idx)].add(x_node)

            if y_node:
                # remove (b, y_node.val) from bi and bi_ixs
                heap_deltas[(b, y_node.val)] -= b_node.freq
                if bi_ixs.get((b, y_node.val)):  # not sure why 'if' needed here;
                    bi_ixs[(b, y_node.val)].remove(b_node)

                # add (R, y_node.val) to bi and bi_ixs
                heap_deltas[(new_type_idx, y_node.val)] += b_node.freq
                bi_ixs[(new_type_idx, y_node.val)].add(new_node)

        # however, the counts shouldn't go negative
        assert uni[a] >= 0
        assert uni[b] >= 0

        for pair, delta in heap_deltas.items():
            if delta > 0:  # these are new insertions, and they can go directly to heap
                assert new_type_idx in pair
                heap.push(pair, delta)
            elif delta < 0:  # one of those subtractions, which can't be directly updated
                assert new_type_idx not in pair
                heap_dirty[pair] += delta
    return vocab
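# --- Illustrative sketch; not part of the library -----------------------------
# Stripped of the max-heap and linked-list bookkeeping, each merge step above
# picks the most frequent adjacent pair and rewrites it as one new symbol R
# (x a b y -> x R y). A slow but straightforward standalone version of a single
# merge (hypothetical helper); unlike learn_codes(), it recounts all pairs from
# scratch instead of applying incremental heap_deltas/heap_dirty updates:
from collections import Counter

def merge_most_frequent_pair(seq, new_symbol):
    pairs = Counter(zip(seq, seq[1:]))
    if not pairs:
        return seq, None
    (a, b), freq = pairs.most_common(1)[0]
    out, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == (a, b):
            out.append(new_symbol)  # replace the bigram with the new code
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out, ((a, b), freq)

# e.g. merge_most_frequent_pair([1, 2, 1, 2, 3], new_symbol=9)
# -> ([9, 9, 3], ((1, 2), 2))
# -------------------------------------------------------------------------------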