コード例 #1
0
ファイル: vocab.py プロジェクト: masonreznov/transformer
def handle(srcf, rsf, vsize=65532):
    """Count the tokens of ``srcf`` and write them to ``rsf`` grouped by frequency.

    Each non-blank line of ``srcf`` is decoded as UTF-8, split on whitespace
    and passed through ``clean_list_iter``; every yielded token is counted.
    The output holds one line per distinct frequency, highest frequency
    first, formatted as ``<frequency> <token> <token> ...``.

    Args:
        srcf: path of the input text file (read in binary mode).
        rsf: path of the output file (written in binary mode, UTF-8).
        vsize: maximum number of tokens to write; the frequency group that
            would exceed this budget is truncated and writing stops.
    """

    counts = {}
    with open(srcf, "rb") as reader:
        for raw in reader:
            stripped = raw.strip()
            if not stripped:
                continue
            for tok in clean_list_iter(stripped.decode("utf-8").split()):
                counts[tok] = counts.get(tok, 0) + 1

    # Invert the mapping: frequency -> [str(frequency), token, token, ...].
    # The leading string is the frequency itself so a group can be joined
    # directly into an output line.
    by_freq = {}
    for tok, cnt in counts.items():
        by_freq.setdefault(cnt, [str(cnt)]).append(tok)

    newline = "\n".encode("utf-8")
    budget = vsize
    with open(rsf, "wb") as writer:
        for cnt in sorted(by_freq, reverse=True):
            group = by_freq[cnt]
            # The first element is the frequency label, not a token.
            n_tokens = len(group) - 1
            if n_tokens > budget:
                # Keep the label plus as many tokens as the budget allows.
                group = group[:budget + 1]
                n_tokens = budget
            writer.write(" ".join(group).encode("utf-8") + newline)
            budget -= n_tokens
            if budget <= 0:
                break
コード例 #2
0
def handle(srcfl, rsf, rslangf, vsize=65532):
    """Build two token-count tables from a list of files and save both.

    ``srcfl`` is scanned in two phases separated by the literal entry
    ``"--target"``:

    * Files before the marker: on every non-blank line the first token is
      counted in a separate table (saved to ``rslangf``), and the remaining
      tokens are counted in the main table. The first token is presumably a
      language tag, judging by the variable names -- confirm with the caller.
    * Files after the marker: every token is counted in the main table.

    If the marker is absent, every file is handled as in the first phase.

    Args:
        srcfl: sequence of input file paths, optionally containing the
            ``"--target"`` marker.
        rsf: destination passed to ``save_vocab`` for the main table.
        rslangf: destination passed to ``save_vocab`` for the first-token table.
        vsize: forwarded to ``save_vocab`` as ``omit_vsize`` for the main table.
    """

    token_counts = {}
    first_token_counts = {}

    # Number of files consumed before the marker (or all of them if the
    # marker never appears).
    n_leading = 0
    for path in srcfl:
        if path == "--target":
            break
        with open(path, "rb") as reader:
            for raw in reader:
                stripped = raw.strip()
                if not stripped:
                    continue
                toks = clean_list(stripped.decode("utf-8").split())
                for tok in toks[1:]:
                    token_counts[tok] = token_counts.get(tok, 0) + 1
                head = toks[0]
                first_token_counts[head] = first_token_counts.get(head, 0) + 1
        n_leading += 1

    # Skip the marker itself (index n_leading) and process what follows.
    for path in srcfl[n_leading + 1:]:
        with open(path, "rb") as reader:
            for raw in reader:
                stripped = raw.strip()
                if not stripped:
                    continue
                for tok in clean_list_iter(stripped.decode("utf-8").split()):
                    token_counts[tok] = token_counts.get(tok, 0) + 1

    save_vocab(token_counts, rsf, omit_vsize=vsize)
    save_vocab(first_token_counts, rslangf, omit_vsize=False)
コード例 #3
0
def handle(srcf, rsf, vsize=65532):
    """Count the tokens of ``srcf`` and hand the counts to ``save_vocab``.

    Each non-blank line is decoded as UTF-8, split on whitespace and passed
    through ``clean_list_iter``; every yielded token is counted.

    Args:
        srcf: path of the input text file (read in binary mode).
        rsf: destination passed to ``save_vocab``.
        vsize: forwarded to ``save_vocab`` as ``omit_vsize``.
    """

    counts = {}
    with open(srcf, "rb") as reader:
        for raw in reader:
            stripped = raw.strip()
            if not stripped:
                continue
            for tok in clean_list_iter(stripped.decode("utf-8").split()):
                if tok in counts:
                    counts[tok] += 1
                else:
                    counts[tok] = 1

    save_vocab(counts, rsf, omit_vsize=vsize)
コード例 #4
0
def collect(fl):
    """Return the set of distinct tokens found in the files listed in ``fl``.

    Each non-blank line is decoded as UTF-8, split on whitespace and passed
    through ``clean_list_iter``; every yielded token is added to the result.

    Args:
        fl: iterable of input file paths (each read in binary mode).

    Returns:
        A ``set`` containing every token seen across all files.
    """

    seen = set()
    for path in fl:
        with open(path, "rb") as reader:
            for raw in reader:
                stripped = raw.strip()
                if stripped:
                    # A set ignores duplicates, so no membership test is needed.
                    seen.update(clean_list_iter(stripped.decode("utf-8").split()))

    return seen
コード例 #5
0
def handle(srcfl):
    """Print the vocabulary size of the files in ``srcfl``, special tokens included.

    Distinct tokens are collected from every non-blank line (decoded as
    UTF-8, split on whitespace, passed through ``clean_list_iter``). The
    reported size adds 4 to the distinct-token count when the module-level
    flag ``has_unk`` is truthy, otherwise 3.

    Args:
        srcfl: iterable of input file paths (each read in binary mode).
    """

    global has_unk

    seen = set()
    for path in srcfl:
        with open(path, "rb") as reader:
            for raw in reader:
                stripped = raw.strip()
                if stripped:
                    # A set ignores duplicates, so no membership test is needed.
                    seen.update(clean_list_iter(stripped.decode("utf-8").split()))

    # Number of special tokens counted on top of the observed vocabulary.
    n_special = 4 if has_unk else 3
    total = len(seen) + n_special

    print("The size of the vocabulary is: %d (with special tokens counted)" % total)