Example #1
def read_raw_bitext(spark, src_file: Union[str, Path], tgt_file: Union[str, Path],
                    src_name='src_raw', tgt_name='tgt_raw') -> Tuple[DataFrame, int]:
    """
    reads bitext to a dataframe
    :param spark:  spark session
    :param src_file: source file to read from
    :param tgt_file:  target file to read from
    :param src_name: name for source col in DF
    :param tgt_name: name for target col in DF
    :return: DataFrame
    """
    if not isinstance(src_file, str):
        src_file = str(src_file)
    if not isinstance(tgt_file, str):
        tgt_file = str(tgt_file)

    src_df = spark.read.text(src_file).withColumnRenamed('value', src_name)
    tgt_df = spark.read.text(tgt_file).withColumnRenamed('value', tgt_name)

    n_src, n_tgt = src_df.count(), tgt_df.count()
    assert n_src == n_tgt, f'{n_src} == {n_tgt} ?'
    log.info(f"Found {n_src:,} parallel records in {src_file, tgt_file}")

    def with_idx(sdf):
        new_schema = StructType(sdf.schema.fields + [StructField("idx", LongType(), False), ])
        return sdf.rdd.zipWithIndex().map(lambda row: row[0] + (row[1],)).toDF(
            schema=new_schema)

    src_df = with_idx(src_df)
    tgt_df = with_idx(tgt_df)
    bitext_df = src_df.join(tgt_df, 'idx', "inner")
    # n_bitext = bitext_df.count()
    # assert n_bitext == n_src, f'{n_bitext} == {n_src} ??'
    return bitext_df, n_src
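A minimal usage sketch for the helper above, assuming it is importable from its module and that pyspark is installed; the file paths and app name are hypothetical.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("bitext-demo").getOrCreate()
bitext_df, n_recs = read_raw_bitext(spark, "train.src", "train.tgt")
bitext_df.show(5)  # columns: idx, src_raw, tgt_raw
print(f"{n_recs:,} parallel records")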
Example #2
def main():
    args = parse_args()
    task = args.pop('task')
    if task == 'learn':
        args.pop('out')  # No output
        args.pop('indices')  # No output
        assert args.get(
            'level'), 'argument --level is required for "learn" task'
        import time
        from datetime import timedelta
        from nlcodec.utils import max_RSS
        st = time.time()
        st_mem = max_RSS()[1]
        learn_vocab(**args)
        delta = timedelta(seconds=time.time() - st)
        et_mem = max_RSS()[1]
        log.info(f"Time taken: {delta}; Memory: {st_mem} --> {et_mem}")
    elif task in ('encode', 'decode'):
        scheme = load_scheme(args.pop('model'))
        inp, out, indices = args['inp'], args['out'], args.get(
            'indices', False)
        if task == 'encode':
            recs = encode(inp, scheme, indices=indices)
            if indices:
                recs = ([str(idx) for idx in seq] for seq in recs)
            recs = (' '.join(seq) for seq in recs)
        else:
            recs = decode(inp, scheme, indices=indices)
        write_lines(recs, out)
    elif task == 'estimate':
        from nlcodec.qestim import estimate
        estimate(codec_path=args['model'], data=args['inp'])
    else:
        raise NotImplementedError(task + ' not implemented')
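Outside the CLI, the same encode path can be sketched per line: load_scheme is used above, and encode_str is the per-line method that appears in Example #9. The model path below is hypothetical, and the return type of encode_str (a list of subword piece strings) is an assumption.

scheme = load_scheme("bpe.model")  # hypothetical model path
for line in ["hello world", "nlcodec makes subwords"]:
    pieces = scheme.encode_str(line)  # assumed: list of subword piece strings
    print(" ".join(pieces))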
Example #3
 def __init__(self, table: List[Type]):
     super().__init__(table=table, invertible=False)
     self.root = self.make_vocab_prefix_trie(self.table)
     log.info(
         f"Vocab size={len(self)}; trie root has nodes={self.root.size}"
         f" but data_nodes={self.root.data_node_count}")
     assert self.unk_idx
Example #4
    def learn(cls,
              data: Iterator[str],
              vocab_size: int = 0,
              min_freq: int = WORD_MIN_FREQ,
              coverage: float = 0,
              term_freqs=False,
              **kwargs) -> List[Type]:
        """
        :param data: input sentences
        :param vocab_size: max vocabulary size.
        :param min_freq: min frequency for inclusion in vocabulary. Excludes types with lower freq
        :param coverage: Character coverage
        :param term_freqs: is data the term_freqs ?
        :param kwargs: place holder for any extra args
        :return:
        """
        assert not kwargs, f'{kwargs} args are not allowed/understood'
        if term_freqs:  # input is term_freqs
            log.info("Restoring term frequencies from input")
            stats, line_count = cls.read_term_freqs(data=data)
        else:  # compute term freqs
            log.info("Computing term frequencies from raw data")
            stats, line_count = cls.term_frequencies(data=data)

        return cls.get_init_vocab(stats, coverage, line_count, min_freq,
                                  vocab_size)
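A minimal sketch of calling this classmethod; the class name is not part of this excerpt, so it is assumed here to be WordScheme (referenced in Example #6), with Type.write_out from Example #22 in scope and a hypothetical output file name.

sents = ["the cat sat on the mat", "the dog sat too"]
table = WordScheme.learn(iter(sents), vocab_size=32, min_freq=1)
Type.write_out(table, out="word.vocab")  # assumes WordScheme and Type are imported from nlcodec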
Example #5
    def __init__(self, data_path: Path, batch_size: Union[int, Tuple[int, int]], batch_meta: BatchMeta,
                 sort_desc: bool = False, batch_first: bool = True, sort_by: str = None,
                 keep_in_mem=False):
        """
        Iterator for reading training data in batches
        :param data_path: path to TSV file
        :param batch_size: number of tokens on the target size per batch; or (max_toks, max_sents)
        :param sort_desc: should the mini batch be sorted by sequence len (useful for RNN api)
        :param keep_in_mem: keep all parts in memory for multipartdb;
           for single part, of course, the part remains in memory.
        """
        self.batch_meta = batch_meta
        self.sort_desc = sort_desc
        if isinstance(batch_size, int):
            self.max_toks, self.max_sents = batch_size, float('inf')
        else:
            self.max_toks, self.max_sents = batch_size
        self.batch_first = batch_first
        self.sort_by = sort_by
        if not isinstance(data_path, Path):
            data_path = Path(data_path)
        self.data_path = data_path
        self.keep_in_mem = keep_in_mem
        assert sort_by in (None, 'eq_len_rand_batch', 'random')

        assert data_path.exists(), f'Invalid State: {data_path} is NOT found.'
        if data_path.is_file():
            self.data = Db.load(data_path, rec_type=IdExample)
        elif data_path.is_dir():
            self.data = MultipartDb.load(data_path, rec_type=IdExample, keep_in_mem=keep_in_mem)
        else:
            raise Exception(f'Invalid State: {data_path} should be a file or dir.')

        log.info(f'Batch Size = {batch_size}, sort_by={sort_by}')
Example #6
    def learn(cls,
              data: Iterator[str],
              vocab_size: int = 0,
              min_freq=WORD_MIN_FREQ,
              coverage=CHAR_COVERAGE,
              min_co_evidence=MIN_CO_EV,
              term_freqs=False,
              **kwargs) -> List[Type]:
        assert vocab_size > 0
        assert not kwargs, f'{kwargs} args are not allowed/understood'
        if term_freqs:
            log.info("Reading term freqs from input")
            tfs, line_count = WordScheme.read_term_freqs(data)
        else:
            log.info("Computing term freqs from input")
            tfs, line_count = WordScheme.term_frequencies(data)

        def init_vocab_factory(char_types):
            return CharScheme.get_init_vocab(char_types,
                                             line_count=line_count,
                                             coverage=coverage,
                                             min_freq=1)

        from .bpe import BPELearn
        vocab = BPELearn.learn_subwords(term_freqs=tfs,
                                        vocab_size=vocab_size,
                                        init_vocab_factory=init_vocab_factory,
                                        min_co_evidence=min_co_evidence)
        return vocab
Example #7
    def read_vocab(
            cls, inp: Union[Path, str,
                            TextIO]) -> Tuple[List['Type'], Optional[Dict]]:

        if isinstance(inp, Path) or isinstance(inp, str):
            rdr = open(inp, 'r', encoding='utf8', errors='ignore')
        else:
            rdr = inp

        lines = list(l.strip() for l in rdr)
        meta = None
        if lines[0].startswith("#{"):
            # metadata such as version; not used as of now
            meta = json.loads(lines.pop(0)[1:])

        # noinspection PyTypeChecker
        vocab: List[Type] = [None] * len(lines)
        for i, line in enumerate(lines):
            v = Type.parse(line=line.rstrip('\n'), vocab=vocab)
            assert v.idx == i
            vocab[i] = v
        if rdr is not inp:
            rdr.close()
        log.info(f"read {len(vocab)} types from {rdr.name}")
        return vocab, meta
Example #8
def make_n_grams_all(sents: Iterable[List[Any]], n):
    grams = coll.Counter()
    n_sent = 0
    for sent in tqdm(sents, mininterval=1, dynamic_ncols=True):
        grams.update(make_n_grams(sent, n))
        n_sent += 1
    log.info(f"Made {n}-grams: types={len(grams)}; tokens={sum(grams.values())}")
    return grams
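make_n_grams itself is not part of this excerpt; the stand-in below is a hypothetical sliding-window version of what make_n_grams_all appears to count, shown only to make the counting concrete.

from typing import Any, Iterable, List, Tuple

def make_n_grams(sent: List[Any], n: int) -> Iterable[Tuple[Any, ...]]:
    # overlapping n-grams as tuples, e.g. bigrams of [a, b, c] -> (a, b), (b, c)
    return (tuple(sent[i:i + n]) for i in range(len(sent) - n + 1))

sents = [["a", "b", "c"], ["b", "c", "d"]]
# make_n_grams_all(sents, 2) would then count ('a','b'):1, ('b','c'):2, ('c','d'):1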
Example #9
 def term_frequencies(cls,
                      data: Iterator[str]) -> Tuple[Dict[str, int], int]:
     stats = coll.Counter()
     line_count = 0
     for line in tqdm(data, mininterval=1):
         stats.update(cls.encode_str(line.strip()))
         line_count += 1
     log.info(
         f"Found {len(stats):,} types and {sum(stats.values()):,} tokens")
     return stats, line_count
Example #10
def rdd_as_db(rdd, db_path: Path, field_names=('x', 'y'), repartition=0, **kwargs) -> MultipartDb:
    # input format: each rec in rdd should be: (id, (x, y))
    if repartition and repartition > 0:
        rdd = rdd.repartition(repartition)

    with MultipartDb.Writer(db_path, field_names=field_names, **kwargs) as writer:
        n = rdd.mapPartitionsWithIndex(writer).count()  # parallel write
    db = MultipartDb.load(db_path)
    log.info(f"Wrote {len(db)} recs in {len(db.part_paths)} parts at {db_path}")
    return db
Example #11
 def parallel_map(cls,
                  mapper,
                  collection,
                  n_cpus=N_CPUS,
                  name='',
                  chunksize=1000):
     assert n_cpus > 1, f'at least 2 CPUs needed, but n_cpus={n_cpus}'
     log.info(f"Going to use {n_cpus} parallel processes {name}")
     with mp.Pool(processes=n_cpus) as pool:
         yield from pool.imap(mapper, collection, chunksize=chunksize)
Example #12
    def __init__(self, path: Union[str, Path]):
        types, meta = Type.read_vocab(path)
        assert meta
        assert meta['max_level'] == Level.subword
        init_table = [t for t in types if t.level < Level.subword]
        self.merge_buffer = [t for t in types if t.level >= Level.subword]
        for m in self.merge_buffer:
            assert len(m.kids) == 2  # exactly two pieces

        super().__init__(table=init_table) # start with character
        self.root = self.make_vocab_prefix_trie(self.table)
        assert self.unk_idx
        log.info(f"initial table has {len(init_table)}; buffer has {len(self.merge_buffer)}")
Example #13
    def make_eq_len_ran_batches(self, max_toks, max_sents=float('inf')):

        batches = self._make_eq_len_batch_ids(max_toks=max_toks,
                                              max_sents=max_sents)
        if not batches:
            raise Exception(f'Found no data. Please check config data paths')
        log.info(
            f"length sorted random batches = {len(batches)}. Shuffling🔀...")
        # every pass introduce some randomness
        random.shuffle(batches)

        for batch_ids in batches:
            batch = [self[_id] for _id in batch_ids]
            yield batch
Example #14
def learn_vocab(inp,
                level,
                model,
                vocab_size,
                min_freq=1,
                term_freqs=False,
                char_coverage=CHAR_COVERAGE,
                min_co_ev=MIN_CO_EV):
    if not min_freq or min_freq < 1:
        min_freq = WORD_MIN_FREQ if level == 'word' else CHAR_MIN_FREQ
        log.info(f"level={level} => default min_freq={min_freq}")
    else:
        log.info(f"level={level} => user given min_freq={min_freq}")
    log.info(f"Learn Vocab for level={level} and store at {model}")
    log.info(f"data ={inp}")
    Scheme = REGISTRY[level]
    args = {}
    if level != 'word':
        args['coverage'] = char_coverage  # no char_coverage for word
    if level == 'bpe':
        args['min_co_evidence'] = min_co_ev
    table = Scheme.learn(inp,
                         vocab_size=vocab_size,
                         min_freq=min_freq,
                         term_freqs=term_freqs,
                         **args)
    Type.write_out(table=table, out=model)
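A minimal sketch of calling learn_vocab directly (the CLI "learn" task in Example #2 ends up here); the file names are hypothetical and `inp` can be any iterator of raw text lines.

with open("train.txt", encoding="utf-8") as inp:  # hypothetical corpus file
    learn_vocab(inp, level="bpe", model="bpe.model", vocab_size=8000)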
Example #15
@contextmanager  # assumes `from contextlib import contextmanager` at module level
def log_resources(name=""):
    """
    Context manager that logs the time taken and memory used by a code block
    :param name: a name to identify the code block in logs
    :return:
    """
    st = time.time()
    st_mem = max_RSS()[1]
    try:
        yield name
    finally:
        delta = timedelta(seconds=time.time() - st)
        et_mem = max_RSS()[1]
        log.info(f"{name} Time: {delta}; Mem: {st_mem} --> {et_mem}")
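A usage sketch, assuming the @contextmanager decorator noted above; the block name and file path are arbitrary.

with log_resources("count-lines"):
    n_lines = sum(1 for _ in open("train.txt", encoding="utf-8"))  # hypothetical workload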
Example #16
def filter_types_coverage(types: Dict[str, int], coverage=1.0) -> Tuple[Dict[str, int], int]:
    assert 0 < coverage <= 1
    tot = sum(types.values())
    includes = {}
    cum = 0
    types = sorted(types.items(), key=lambda x: x[1], reverse=True)
    for t, f in types:
        cum += f / tot
        includes[t] = f
        if cum >= coverage:
            break
    log.info(f'Coverage={cum:g}; requested={coverage:g}')
    excludes = {ch: ct for ch, ct in types if ch not in includes}
    unk_count = sum(excludes.values())
    log.warning(f'UNKed total toks:{unk_count} types={len(excludes)} from types:{excludes}')
    return includes, unk_count
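A tiny worked example of the coverage cut-off: with coverage=0.9 on the toy counts below, the two most frequent types are kept and the remaining 3 occurrences become the UNK count.

char_freqs = {'e': 50, 't': 30, 'q': 2, 'z': 1}
kept, unk_count = filter_types_coverage(char_freqs, coverage=0.9)
# kept == {'e': 50, 't': 30}; unk_count == 3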
Example #17
    def get_init_vocab(cls,
                       term_freqs,
                       coverage: float = 0,
                       line_count=None,
                       min_freq=WORD_MIN_FREQ,
                       vocab_size=-1):
        vocab = Reseved.with_reserved_types()
        res_stats = {
            r_type.name: term_freqs.pop(r_type.name)
            for r_type in vocab if r_type.name in term_freqs
        }
        if res_stats:
            log.warning(f"Found reserved types in corpus: {res_stats}")
        # Order of trimming techs: 1. coverage, 2. min freqs, 3. size cut off
        unk_count = 0
        if coverage:
            assert 0 < coverage <= 1
            term_freqs, coverage_unk_count = filter_types_coverage(
                term_freqs, coverage=coverage)
            unk_count += coverage_unk_count
        term_freqs = sorted(term_freqs.items(),
                            key=lambda x: x[1],
                            reverse=True)
        if min_freq and min_freq > 1:
            log.info(
                f"Excluding terms with freq < {min_freq}; |freq >= 1|: {len(term_freqs):,}"
            )
            unk_count += sum(f for t, f in term_freqs if f < min_freq)
            term_freqs = [(t, f) for t, f in term_freqs if f >= min_freq]
            log.info(f"|freq >= {min_freq}| : {len(term_freqs):,}")

        if vocab_size > 0 and len(vocab) + len(term_freqs) > vocab_size:
            log.info(f"Truncating vocab at size={vocab_size}")
            unk_count += sum(f
                             for t, f in term_freqs[vocab_size - len(vocab):])
            term_freqs = term_freqs[:vocab_size - len(vocab)]

        # update reserved types with corpus freqs
        for idx, t in enumerate(vocab):
            freq = 0
            if t.name in res_stats:
                freq = res_stats.pop(t.name)
            if idx == Reseved.UNK_IDX:
                freq += unk_count
            if idx in {Reseved.BOS_IDX, Reseved.EOS_IDX, Reseved.CLS_IDX
                       } and line_count:
                freq += line_count
            if freq:
                log.warning(
                    f"Update frequency for reserved type {t} with {freq}")
                vocab[idx] = t.copy(freq=freq)
        vocab += [
            Type(name=name, idx=idx, freq=freq, level=cls.level)
            for idx, (name, freq) in enumerate(term_freqs, start=len(vocab))
        ]
        log.info(f"Total {cls} vocab size {len(vocab):,}")
        return vocab
Example #18
    def learn_subwords(cls,
                       term_freqs: Dict[str, int],
                       vocab_size: int,
                       min_co_evidence: int = DEF_MIN_CO_EV,
                       init_vocab_factory=None) -> List[Type]:
        """
        :param term_freqs:
        :param vocab_size: final vocab size: reserved + chars + user_specified  + merges;
          special case, when `vocab_size=-1` the returned vocab will have just reserved + chars
        :param min_co_evidence: min co evidence for pair merges
        :param char_coverage: percentage of characters to be covered by inital char_freqs
        :param word_min_freq: words below this frequency will be excluded for learning BPE
        :return: List of Type
        """

        log.info(f"Total types: {len(term_freqs)}")
        term_freqs = {
            cls.prepare_word(word): freq
            for word, freq in term_freqs.items()
        }

        char_freqs = coll.defaultdict(int)
        for term, freq in term_freqs.items():
            for ch in term:
                char_freqs[ch] += freq
            """TODO: test this behavior; similar to subword-nmt v0.2
            for ch in term[:-2]:  # skip the last two: ending and the whitespace marker
                char_freqs[ch] += freq
            char_freqs[term[-2:]] += freq  # ending + whitespace marker go together as a single byte
            """
        init_vocab = init_vocab_factory(
            char_freqs)  # create an initial vocabulary of chars
        if vocab_size == -1:
            log.warning(
                f'Since vocab_size={vocab_size}; not going to do any L1 merges'
            )
            log.info(f'Found initial vocab size of {len(init_vocab)}')
            return init_vocab

        return cls._learn_codes(term_freqs,
                                init_vocab,
                                min_co_evidence=min_co_evidence,
                                vocab_size=vocab_size)
Example #19
 def validate_index(self):
     """
     Call this any time to check if the index of uni bi bi_ixs are valid.
     Raises exception on invalid index
     :return:
     """
     # This is code doesnt work with fast but new dirty heap updates
     max_code = max(self.uni)
     max_idx = max(t.idx for t in self.vocab)
     if not (max_code < self.vocab_size and max_code <= max_idx):
         raise ValueError(
             f'Vocab size is {self.vocab_size}, but max_code is {max_code}; max_idx={max_idx}'
         )
     if not len(self.bi) == len(self.bi_ixs):
         raise ValueError(
             f"|bi|={len(self.bi)} and |bi_idxs|={len(self.bi_ixs)} are not same"
         )
     for bigm, freq in self.bi.items():
         if not freq >= 0:
             raise ValueError(
                 f"{bigm} has freq {freq}; expected positive val")
         if not bigm in self.bi_ixs:
             raise ValueError(f"{bigm} exists in bi but not in bi_ixs")
         idx_freq = sum(n.freq for n in self.bi_ixs[bigm])
         if not freq == idx_freq:
             raise ValueError(
                 f"{bigm} has freq={freq} bi but has {idx_freq} bi_ixs refs"
             )
         # less than unigram freqs
         if not freq <= self.uni[bigm[0]]:
             raise ValueError(
                 f"{bigm} has freq={freq} bi but {bigm[0]} has {self.uni[bigm[0]]}"
             )
         if not freq <= self.uni[bigm[1]]:
             raise ValueError(
                 f"{bigm} has freq={freq} bi but {bigm[1]} has {self.uni[bigm[1]]}"
             )
     for uni, freq in self.uni.items():
         if not freq >= 0:
             raise ValueError(
                 f"{uni} has freq={freq}; expected positive value")
     log.info(f"Index is valid")
Example #20
    def create_index(self, data):
        """
        1. encoder all : List[str] -> List[List[Int]]
        2. Count unigram freqs => uni
              2a) measure imbalance
              2b) measure seq lens
        3. Index all bigrams:: bi_ix (a,b) -> Node 
           Oh, yeah, convert all List[Int] to Doubly Linked List
           Create seq_len: List[int] each node should have an integer index of sequence in corpus
           
        4. ask bpe for next merge. gives (a,b) -> t
            a) count how many bigrams in index. 
                insert uni[t] = count
                reduce  uni[a], uni[b] by that count
            b) loop through each indexed node 
               remove b. update a;  update links :: x a b y => x t y
               insert bi_ix[x, t] and bi_ix[t, y]
               using seq number stored on node, decrement length by 1. assert len >= 1
            c) measure imbalance, and avg length
        """
        log.info("Encoding and creating index")
        enc_data = self.bpe.encode_parallel(data)
        uni: Dict[int, int] = coll.defaultdict(int)
        bi_ixs: Dict[Bigram, Set[LnNode]] = coll.defaultdict(set)
        seq_lens = []
        with tqdm(enumerate(enc_data)) as data_bar:
            for idx, seq in data_bar:
                seq = LnNode.from_seq(seq, data=idx)
                seq_lens.append(len(seq))
                for node in seq:
                    uni[node.val] += 1
                for node in seq[:-1]:
                    bi_ixs[(node.val, node.right.val)].add(node)
                bar_msg = f'MaxRSS={max_RSS()[1]}'
                data_bar.set_postfix_str(bar_msg, refresh=False)

        self.seq_lens = np.array(seq_lens)
        log.info(f"Found {len(self.seq_lens)} sentences")
        for idx, freq in uni.items():
            self.uni[idx] = freq
        self.bi_ixs = bi_ixs
Example #21
    def _learn_codes(cls,
                     term_freqs: Dict[str, int],
                     vocab: List[Type],
                     vocab_size: int,
                     init_list: List[str] = None,
                     min_co_evidence: int = DEF_MIN_CO_EV) -> List[Type]:
        """
        :param term_freqs: words types and frequencies
        :param vocab: initial vocab; usually reserved and alphabet
        :param vocab_size: desired vocabulary size
        :param init_list: any reserved words
        :return: List[str] word pieces
        """

        vocab = copy.copy(vocab)
        log.info(
            f"Found {len(term_freqs)} types before splitting; initial vocab {len(vocab)}"
        )
        if init_list:
            log.info(f'Adding {len(init_list)} types to the initial vocab')
            assert not any(
                ' ' in w
                for w in init_list), 'spaces not allowed in init_list words'
            vocab += [
                Type(cls.prepare_word(w), level=Level.user, idx=idx, freq=0)
                for idx, w in enumerate(init_list, start=len(vocab))
            ]

        rev_idx: Dict[str, int] = {word.name: word.idx for word in vocab}
        assert len(rev_idx) == len(vocab)  # one to one map
        assert vocab_size > len(vocab), f'vocab_size={vocab_size} is too small;' \
            f' found {len(vocab)} in the init vocab! Set a value larger than {len(vocab)}'

        seqs_freqs = cls._make_idxs(rev_idx, term_freqs)
        learner = BPELearn(seqs_freqs, vocab=vocab)
        final_vocab = learner.learn_codes(n_merges=vocab_size - len(vocab),
                                          min_co_evidence=min_co_evidence,
                                          code_level=Level.subword)
        return final_vocab
Example #22
    def write_out(cls, table: List['Type'], out: Union[Path, str, TextIO]):

        if isinstance(out, Path) or isinstance(out, str):
            wrtr = open(out, 'w', encoding='utf8', errors='ignore')
        else:
            wrtr = out

        levels = dict(coll.Counter(v.level for v in table))
        max_level = max(levels.keys())
        meta = dict(total=len(table),
                    version=__version__,
                    levels=levels,
                    max_level=max_level,
                    created=str(datetime.now()))
        meta = json.dumps(meta)
        wrtr.write(f"#{meta}\n")
        for i, item in enumerate(table):
            assert i == item.idx, f'{item} expected index {i}'
            wrtr.write(item.format() + '\n')
        if wrtr is not out:
            wrtr.close()
        log.info(f"Wrote {len(table)} to {wrtr.name}")
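A round-trip sketch pairing write_out with read_vocab from Example #7; `table` stands for any previously learned List[Type] and the file name is hypothetical.

Type.write_out(table, out="vocab.tsv")       # `table` is a previously learned List[Type]
loaded, meta = Type.read_vocab("vocab.tsv")
assert len(loaded) == len(table) and meta["total"] == len(table)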
Example #23
    def create_index(self, seqs):
        log.info(
            "Going to build corpus stats index; This might take lot of time and memory"
        )
        n_seqs, n_ignored, n_replaced, bar_msg = 0, 0, 0, ''
        with tqdm(enumerate(seqs),
                  unit='seqs',
                  dynamic_ncols=True,
                  mininterval=1) as data_bar:
            for idx, seq in data_bar:
                freq = 1  # default = 1 freq
                if isinstance(seq, tuple):  # if freq is available
                    seq, freq = seq

                n_seqs += 1
                if idx == 0:  # basic validation
                    assert isinstance(seq, list)  # first sequence should be tokenized (a list)
                    assert isinstance(seq[0], int)  # sequence items should be ints or code points
                if not seq:
                    log.warning(f"Skipping empty sequence at idx {idx + 1}")
                    continue

                nodes = LnNode.from_seq(seq, freq=freq)
                assert len(seq) == len(nodes)

                for i in range(len(seq) - 1):  # the last position left out
                    bigm = (seq[i], seq[i + 1])
                    self.bi[bigm] += freq
                    assert nodes[i] not in self.bi_ixs[bigm]
                    self.bi_ixs[bigm].add(nodes[i])  # bigm found at node i
                    self.uni[seq[i]] += freq
                self.uni[seq[-1]] += freq  # the last unigram count; not covered in the loop above
                bar_msg = f'MaxRSS={max_RSS()[1]}'
                data_bar.set_postfix_str(bar_msg, refresh=False)
        log.info(f"Created index; {bar_msg}")
Example #24
    def learn_codes(self, n_merges: int, min_co_evidence,
                    code_level: int) -> List[Type]:
        """
        :param n_merges: how many more merges
        :param min_co_evidence: min evidence (co-occurrence frequency);
         causes early stop upon failure
        :param code_level: what level to use for new code types created during merge
            for instance level=1 for word bpe; level=2 for seq bpe
        :return: the full vocabulary (List[Type]), including the newly created merge types
        """
        uni, bi_ixs = self.uni, self.bi_ixs
        heap = MaxHeap(self.bi)
        heap_dirty = coll.defaultdict(
            int)  # subtractions aren't updated in max-heap, they are here
        vocab = self.vocab
        for i in range(n_merges):
            # Using MaxHeap for faster lookup of max. But heap gets a bit dirty, so a bit of cleanup
            max_pair, pair_freq = heap.pop()
            while max_pair in heap_dirty:  # clean all max pairs until a clean value
                freq_update = heap_dirty.pop(max_pair)
                assert freq_update < 0  # only decrements are valid. increments make this wrong
                corr_freq = pair_freq + freq_update  # correct value
                assert corr_freq >= 0, f'{max_pair}:{pair_freq}, Δ={freq_update} = {corr_freq}'
                if corr_freq > 0:  # exclude zero count
                    heap.push(max_pair, corr_freq)
                max_pair, pair_freq = heap.pop()

            # here the actual merge logic begins
            if pair_freq < min_co_evidence:
                log.warning(f"Early stop; max evidence found is {pair_freq} "
                            f"but min required is {min_co_evidence}")
                break

            new_type_idx = len(vocab)
            a, b = max_pair
            log.info(
                f"{(100 * i / n_merges):.2f}% :: {new_type_idx} || {a:4}:{uni[a]:5}"
                f" || {b:4}:{uni[b]:5} || {pair_freq:,} || {vocab[a].name} {vocab[b].name}"
            )

            # code -> bigram (flatten out the bigram; resolve interim codes)
            new_type = Type(vocab[a].name + vocab[b].name,
                            idx=new_type_idx,
                            freq=pair_freq,
                            level=code_level,
                            kids=(vocab[a], vocab[b]))
            vocab.append(new_type)

            # updates: update bigram and unigram counts
            uni[new_type_idx] = pair_freq  # this bigram is now a new unigram
            # unigram counts drop ; since some of their bigrams are removed
            uni[a] -= pair_freq
            uni[b] -= pair_freq
            heap_deltas = coll.defaultdict(int)
            update_nodes = bi_ixs.pop(max_pair)  # also removed from bi_ixs
            for node in update_nodes:
                # -- x a b y --
                x_node, b_node = node.left, node.right
                if node.is_unlinked or (a == b and new_type.idx
                                        in (node.val, b_node.val)):
                    # this happens in the cases like "x a a a a y"
                    uni[a] += node.freq
                    uni[b] += node.freq
                    uni[new_type.idx] -= node.freq
                    continue

                y_node = b_node.right
                dirty = node.val != a or b_node.val != b  # check that the linked list is proper
                if dirty:
                    log.warning(
                        f'Expected {a, b} but found {node.val, b_node.val}'
                        f'\n {node, b_node}'
                        f'\n--{vocab[a].signature()} =='
                        f' {vocab[node.val].signature() if node.val != a else "OK"}'
                        f'\n--{vocab[b].signature()} =='
                        f' {vocab[b_node.val].signature() if b_node.val != b else "OK"}'
                    )
                    log.warning(
                        f"a={a}, b={b} || a_node={node}, b_node={b_node}")
                assert not dirty
                assert node.freq == b_node.freq

                # update : x a b y => x R y
                b_node.delete(unlink=True)  # delete() takes care of linking a → y and a ← y
                new_node = node  # reuse a node as new_node/R
                new_node.val = new_type_idx  # reuse a as new_node/R
                # Note: the above edits to a and b nodes do-not/should-not change __hash__

                if x_node:
                    # remove (x_node_val, a) from bi and bi_ixs
                    heap_deltas[(x_node.val, a)] -= x_node.freq
                    if bi_ixs.get((x_node.val, a)):
                        # not sure why 'if' needed here;
                        bi_ixs[(x_node.val, a)].remove(x_node)

                    # add (x_node_val, R) to bi and bi_ixs
                    heap_deltas[(x_node.val, new_type_idx)] += x_node.freq
                    bi_ixs[(x_node.val, new_type_idx)].add(x_node)
                if y_node:
                    # remove (b, y_node.val) from bi and bi_ixs
                    heap_deltas[(b, y_node.val)] -= b_node.freq
                    if bi_ixs.get((b, y_node.val)):
                        # not sure why 'if' needed here;
                        bi_ixs[(b, y_node.val)].remove(b_node)

                    # add (R, y_node.val) to bi and bi_ixs
                    heap_deltas[(new_type_idx, y_node.val)] += b_node.freq
                    bi_ixs[(new_type_idx, y_node.val)].add(new_node)

            # however, the counts shouldn't go negative
            assert uni[a] >= 0
            assert uni[b] >= 0

            for pair, delta in heap_deltas.items():
                if delta > 0:  # these are new insertions, and they can go directly to heap
                    assert new_type_idx in pair
                    heap.push(pair, delta)
                elif delta < 0:  # one of those subtractions, which cant be directly updated
                    assert new_type_idx not in pair
                    heap_dirty[pair] += delta

        return vocab
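The bookkeeping above (rewriting x a b y as x t y while shifting bigram counts) is easier to follow on a toy corpus. The standalone sketch below performs plain dictionary-based BPE merges, without nlcodec's heap and linked-list index, purely to illustrate the merge step; the toy word frequencies are made up.

import collections

def count_pairs(word_freqs):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = collections.Counter()
    for symbols, freq in word_freqs.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge_pair(pair, word_freqs):
    """Rewrite every occurrence of `pair` (a, b) as the merged symbol a+b."""
    a, b = pair
    merged = a + b
    out = {}
    for symbols, freq in word_freqs.items():
        new_syms, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                new_syms.append(merged)
                i += 2
            else:
                new_syms.append(symbols[i])
                i += 1
        out[tuple(new_syms)] = freq
    return out

# words as tuples of characters (with an end marker) and their corpus frequencies
word_freqs = {('l', 'o', 'w', '</w>'): 5, ('l', 'o', 'w', 'e', 'r', '</w>'): 2}
for _ in range(3):
    pairs = count_pairs(word_freqs)
    best, freq = pairs.most_common(1)[0]
    print(f"merge {best} (co-evidence={freq})")
    word_freqs = merge_pair(best, word_freqs)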
Example #25
    def shrink_vocab(self,
                     files: List[Path],
                     min_freq: int,
                     save_at: Optional[Path] = None) -> List[int]:
        """
        :param files:
        :param min_freq:
        :param save_at:
        :return:
        """
        """"
- Accept a list of files
- compute term frequencies
- Eliminate types with zero counts
- Preserve reserved types even if they have zero counts
- Save the resulting model at a given file path
- Return index mapping between old and new, so we can go back to model and shrink embedding tables
        """
        from tqdm import tqdm
        freqs = coll.Counter()
        for file in files:
            log.info(f'Computing term frequencies from {file}')
            with IO.reader(file) as lines:
                for line in tqdm(lines):
                    freqs.update(self.encode(line))
        assert len(self.table) > max(freqs.keys())
        removals = [False] * len(self.table)
        for idx, typ in enumerate(self.table):
            if typ.level == Level.reserved:
                continue  # i.e. don't remove
            removals[idx] = freqs[idx] < min_freq  # remove if min_freq threshold is not met

        # now make sure to preserve all the sub pieces leading to the pieces that retain
        for idx in range(len(self.table) - 1, -1, -1):
            combo = self.table[idx]
            assert combo.idx == idx
            if not removals[combo.idx] and combo.kids:
                for piece in combo.kids:
                    if removals[piece.idx]:
                        removals[piece.idx] = False  # don't remove this piece; a retained type depends on it

        mapping = []
        for idx, (is_remove, typ) in enumerate(zip(removals, self.table)):
            assert idx == typ.idx
            if is_remove:
                continue
            mapping.append(idx)

        log.info(
            f"Shrinking vocab tables: {len(self.table)} --> {len(mapping)} ")
        rev_mapping = {
            old_idx: new_idx
            for new_idx, old_idx in enumerate(mapping)
        }
        old_table = self.table
        new_table = []
        for new_idx, old_idx in enumerate(mapping):
            assert len(new_table) == new_idx
            old_t = old_table[old_idx]
            new_kids = [new_table[rev_mapping[k.idx]]
                        for k in old_t.kids] if old_t.kids else None
            new_t = old_t.copy(idx=new_idx, kids=new_kids)
            new_table.append(new_t)
        if save_at:
            Type.write_out(new_table, out=save_at)
        return mapping
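A usage sketch: shrink an existing codec to the types seen in new data. load_scheme appears in Example #2; it is assumed here that the loaded scheme is of the class defining shrink_vocab above, and the file paths are hypothetical.

from pathlib import Path

codec = load_scheme("bpe.model")  # assumed to be an instance of the class above
mapping = codec.shrink_vocab(files=[Path("new_domain.txt")], min_freq=2,
                             save_at=Path("bpe.shrunk.model"))
# mapping[new_idx] == old_idx, handy for shrinking embedding tables to match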
Example #26
    def track_merges(self, do_log=False):
        uni = self.uni
        bi_ixs = self.bi_ixs
        seq_lens = self.seq_lens
        res = []
        while self.bpe.merges_remaining() > 0:
            """
            ask bpe for next merge. gives (a,b) -> t
            a) count how many bigrams in index. 
                insert uni[t] = count
                reduce  uni[a], uni[b] by that count
            b) loop through each indexed node 
               remove b. update a;  update links :: x a b y => x t y
               insert bi_ix[x, t] and bi_ix[t, y]
               using seq number stored on node, decrement length by 1. assert len >= 1
            c) measure imbalance, and avg length
            """
            k = len(self.bpe)
            emd = self.earth_mov_dist(self.uni[:k])
            mu_len = self.mean_seq_len()
            res.append((k, emd, mu_len))
            if do_log:
                log.info(f"{k}: imbalance={emd:g} mean_seq_len={mu_len:g}."
                         f"Next: {self.bpe.peek_merge().signature()}")
            new_type = self.bpe.merge()
            a_type, b_type = new_type.kids
            a, b = a_type.idx, b_type.idx
            update_nodes = bi_ixs.pop((a, b), None)
            if update_nodes:
                freq = sum(u.freq for u in update_nodes)
                uni[new_type.idx] = freq
                uni[a] -= freq
                uni[b] -= freq
                for u in update_nodes:
                    b_node = u.right
                    if u.is_unlinked or (a == b and new_type.idx in (u.val, b_node.val)):
                        # happens due to repeats like X A A A A Y
                        uni[new_type.idx] -= u.freq
                        uni[a] += u.freq
                        uni[b] += u.freq  # add the count back for b as well (cf. the same skip case in learn_codes)
                        continue
                    assert u.val == a
                    assert b_node.val == b, f'expected {a, b} found {u.val, b_node.val} at {b_node}'
                    u.val = new_type.idx
                    b_node.delete(unlink=True)

                    seq_idx = u.data
                    seq_lens[seq_idx] -= 1
                    assert seq_lens[seq_idx] > 0

                    if u.left:
                        if bi_ixs.get((u.left.val, a)):
                            bi_ixs[(u.left.val, a)].remove(u.left)
                        bi_ixs[(u.left.val, u.val)].add(u.left)
                    if u.right:
                        if bi_ixs.get((b, u.right.val)):
                            bi_ixs[(b, u.right.val)].remove(b_node)
                        bi_ixs[(u.val, u.right.val)].add(u)
        k = len(self.bpe)
        emd = self.earth_mov_dist(self.uni[:k])
        mu_len = self.mean_seq_len()
        res.append((k, emd, mu_len))
        log.info(f"{k}: imbalance={emd:g} mean_seq_len={mu_len:g}.")
        return res