Code example #1
File: sort.py  Project: masonreznov/transformer
def handle(srcfs, srcft, tgtfs, tgtft, max_len=256, remove_same=False, shuf=True, max_remove=False):

	# reserve two positions out of max_len, presumably for special tokens
	_max_len = max(1, max_len - 2)

	data = {}

	# read the source/target corpora line by line in parallel, skipping
	# pairs where either side is empty
	with open(srcfs, "rb") as fs, open(srcft, "rb") as ft:
		for ls, lt in zip(fs, ft):
			ls, lt = ls.strip(), lt.strip()
			if ls and lt:
				ls, slen = clean_liststr_lentok(ls.decode("utf-8").split())
				lt, tlen = clean_liststr_lentok(lt.decode("utf-8").split())
				# keep pairs within the length limit, bucketed by total token count
				if (slen <= _max_len) and (tlen <= _max_len):
					lgth = slen + tlen
					data = dict_insert_list(data, (ls, lt,), lgth, tlen)

	ens = "\n".encode("utf-8")

	# write the buckets back out in length-sorted order, optionally
	# de-duplicating (remove_same) and shuffling (shuf) within each bucket
	with open(tgtfs, "wb") as fs, open(tgtft, "wb") as ft:
		for tmp in iter_dict_sort(data):
			ls, lt = zip(*tmp)
			if len(ls) > 1:
				if remove_same:
					ls, lt = maxfreq_filter(ls, lt, max_remove)
				if shuf:
					ls, lt = shuffle_pair(ls, lt)
			fs.write("\n".join(ls).encode("utf-8"))
			fs.write(ens)
			ft.write("\n".join(lt).encode("utf-8"))
			ft.write(ens)
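
A minimal usage sketch for this variant follows; the file names are hypothetical, and clean_liststr_lentok, dict_insert_list, iter_dict_sort, maxfreq_filter and shuffle_pair are assumed to be importable from the project's own utility modules.

# Hypothetical driver (illustrative only): sort a parallel corpus pair by
# length, writing the reordered copies next to the originals.
if __name__ == "__main__":
	handle("train.src", "train.tgt", "train.srt.src", "train.srt.tgt", max_len=256)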
Code example #2
File: merge.py  Project: hfxunlp/transformer
	def write_data(data, wfl, ens, shuf=True, max_remove=False):

		lines = zip(*data)
		if len(data) > 1:
			if max_remove:
				lines = maxfreq_filter(*lines)
			if shuf:
				lines = shuffle_pair(*lines)
		for du, f in zip(lines, wfl):
			f.write(ens.join(du))
			f.write(ens)
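
The leading indentation shows this write_data is a nested helper inside a larger merge routine; the sketch below only illustrates the argument shapes it appears to expect (aligned byte-string tuples plus a matching list of open output files), assuming it were exposed at module level with maxfreq_filter and shuffle_pair importable.

# Illustrative shapes only (hypothetical data and file names):
data = [(b"a b c", b"x y z"), (b"d e", b"u v")]
ens = "\n".encode("utf-8")
with open("out.src", "wb") as fsrc, open("out.tgt", "wb") as ftgt:
	write_data(data, [fsrc, ftgt], ens, shuf=True, max_remove=False)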
Code example #3
File: merge.py  Project: masonreznov/transformer
    def write_data(data, fs, ft, ens, rsame, shuf, mclean):

        for tmp in iter_dict_sort(data):
            ls, lt = zip(*tmp)
            if len(ls) > 1:
                if rsame:
                    ls, lt = maxfreq_filter(ls, lt, mclean)
                if shuf:
                    ls, lt = shuffle_pair(ls, lt)
            fs.write("\n".join(ls).encode("utf-8"))
            fs.write(ens)
            ft.write("\n".join(lt).encode("utf-8"))
            ft.write(ens)
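
Unlike the previous helper, this nested write_data takes explicit source/target handles and walks the length-keyed index itself via iter_dict_sort, encoding the stored strings at write time. A sketch of a plausible call, assuming the helper were exposed at module level and that dict_insert_list and iter_dict_sort are importable (as used in code example #1):

# Illustrative only: data is the nested length-keyed mapping built by
# dict_insert_list, holding decoded source/target strings.
data = {}
data = dict_insert_list(data, ("a b c", "x y z",), 6, 3)
ens = "\n".encode("utf-8")
with open("out.src", "wb") as fsrc, open("out.tgt", "wb") as ftgt:
    write_data(data, fsrc, ftgt, ens, rsame=False, shuf=True, mclean=False)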
Code example #4
File: sort.py  Project: hfxunlp/transformer
def handle(srcfl,
           tgtfl,
           max_len=256,
           remove_same=False,
           shuf=True,
           max_remove=False):

    _max_len = max(1, max_len - 2)

    # a set-based insert presumably drops exact duplicates up front when only
    # remove_same is requested; otherwise every occurrence is kept in a list
    _insert_func = dict_insert_set if remove_same and (
        not max_remove) else dict_insert_list
    data = {}

    with FileList(srcfl, "rb") as fl:
        for lines in zip(*fl):
            lines = [line.strip() for line in lines]
            if all(lines):
                lines, lens = zip(*[
                    clean_liststr_lentok(line.decode("utf-8").split())
                    for line in lines
                ])
                # every side must fit within the reserved limit
                if all_le(lens, _max_len):
                    lgth = sum(lens)
                    ls = lines[0]
                    data = _insert_func(
                        data, tuple(line.encode("utf-8") for line in lines),
                        ls[:ls.find(" ")], lgth, *reversed(lens[1:]))

    ens = "\n".encode("utf-8")

    with FileList(tgtfl, "wb") as fl:
        for tmp in iter_dict_sort(data):
            lines = zip(*tmp)
            if len(tmp) > 1:
                if max_remove:
                    lines = maxfreq_filter(*lines)
                if shuf:
                    lines = shuffle_pair(*lines)
            for du, f in zip(lines, fl):
                f.write(ens.join(du))
                f.write(ens)
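
A hypothetical command-line wiring for this multi-file variant (not taken from the project): it only assumes the input corpora and the matching outputs are passed as two equal-length argument groups.

# Illustrative only: first half of the arguments are the parallel inputs,
# second half the corresponding sorted outputs.
if __name__ == "__main__":
    import sys
    _nf = (len(sys.argv) - 1) // 2
    handle(sys.argv[1:1 + _nf], sys.argv[1 + _nf:])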
Code example #5
File: sort.py  Project: hfxunlp/transformer
def handle(srcfl,
           tgtfl,
           max_len=256,
           remove_same=False,
           shuf=True,
           max_remove=False):

    _max_len = max(1, max_len - 2)

    _insert_func = dict_insert_set if remove_same and (
        not max_remove) else dict_insert_list
    data = {}
    cache = []

    with FileList(srcfl, "rb") as fl:
        for lines in zip(*fl):
            lines = [line.strip() for line in lines]
            if all(lines):
                lines, lens = zip(*[
                    clean_liststr_lentok(line.decode("utf-8").split())
                    for line in lines
                ])
                if all_le(lens, _max_len):
                    lgth = sum(lens)
                    cache.append((
                        lines,
                        lens,
                    ))
                else:
                    # a pair that exceeds the length limit ends the current
                    # block: flush the cached document into the size index
                    if cache:
                        nsent = len(cache)
                        lines, lens = zip(*cache)
                        lines = zip(*lines)
                        lens = zip(*lens)
                        mxlens = [max(mu) for mu in lens]
                        slens = [sum(mu) for mu in lens]
                        lines = tuple("\n".join(lu) for lu in lines)
                        data = _insert_func(
                            data,
                            tuple(line.encode("utf-8") for line in lines),
                            nsent, sum(mxlens), *reversed(mxlens[1:]),
                            sum(slens), *reversed(slens[1:]))
                        cache = []
            else:
                # a blank line on any side ends the current document block too
                if cache:
                    nsent = len(cache)
                    lines, lens = zip(*cache)
                    lines = zip(*lines)
                    lens = zip(*lens)
                    mxlens = [max(mu) for mu in lens]
                    slens = [sum(mu) for mu in lens]
                    lines = tuple("\n".join(lu) for lu in lines)
                    data = _insert_func(
                        data, tuple(line.encode("utf-8") for line in lines),
                        nsent, sum(mxlens), *reversed(mxlens[1:]), sum(slens),
                        *reversed(slens[1:]))
                    cache = []

    ens = "\n\n".encode("utf-8")
    with FileList(tgtfl, "wb") as fl:
        for tmp in iter_dict_sort(data):
            lines = zip(*tmp)
            if len(tmp) > 1:
                if max_remove:
                    lines = maxfreq_filter(*lines)
                if shuf:
                    lines = shuffle_pair(*lines)
            for du, f in zip(lines, fl):
                f.write(ens.join(du))
                f.write(ens)
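
For orientation, an illustrative toy run of this block-level variant; the file names are invented, and since the code above only flushes a block on a blank or over-long line, the toy input ends with a trailing blank line.

# Illustrative only: one sentence per line, document blocks separated by
# blank lines; both sides of the toy pair reuse the same text.
_doc = "sent 1 of doc A\nsent 2 of doc A\n\nsent 1 of doc B\n\n"
with open("toy.src", "w", encoding="utf-8") as f:
    f.write(_doc)
with open("toy.tgt", "w", encoding="utf-8") as f:
    f.write(_doc)
handle(["toy.src", "toy.tgt"], ["toy.srt.src", "toy.srt.tgt"])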
Code example #6
def handle(srcfs,
           srcft,
           tgtfs,
           tgtft,
           remove_same=False,
           shuf=True,
           max_remove=False):

    data = {}
    cache = []
    # running per-document statistics: maximum and total token counts per side
    mxtoks = mxtokt = ntoks = ntokt = 0

    with open(srcfs, "rb") as fs, open(srcft, "rb") as ft:
        for ls, lt in zip(fs, ft):
            ls, lt = ls.strip(), lt.strip()
            if ls and lt:
                ls, slen = clean_liststr_lentok(ls.decode("utf-8").split())
                lt, tlen = clean_liststr_lentok(lt.decode("utf-8").split())
                cache.append((
                    ls,
                    lt,
                ))
                if slen > mxtoks:
                    mxtoks = slen
                if tlen > mxtokt:
                    mxtokt = tlen
                ntoks += slen
                ntokt += tlen
            else:
                # a blank line ends the current document: flush it into the
                # size-keyed index and reset the per-document counters
                if cache:
                    nsent = len(cache)
                    ls, lt = zip(*cache)
                    _tmp = (
                        "\n".join(ls),
                        "\n".join(lt),
                    )
                    data = dict_insert_set(data, _tmp, nsent, mxtoks + mxtokt,
                                           mxtokt, ntoks + ntokt, ntokt)
                    cache = []
                    mxtoks = mxtokt = ntoks = ntokt = 0
        # flush whatever remains after the last line (no trailing blank line)
        if cache:
            nsent = len(cache)
            ls, lt = zip(*cache)
            _tmp = (
                "\n".join(ls),
                "\n".join(lt),
            )
            data = dict_insert_set(data, _tmp, nsent, mxtoks + mxtokt, mxtokt,
                                   ntoks + ntokt, ntokt)
            cache = []
            mxtoks = mxtokt = ntoks = ntokt = 0

    ens = "\n\n".encode("utf-8")

    with open(tgtfs, "wb") as fs, open(tgtft, "wb") as ft:
        for tmp in iter_dict_sort(data):
            ls, lt = zip(*tmp)
            if len(ls) > 1:
                if remove_same:
                    ls, lt = maxfreq_filter(ls, lt, max_remove)
                if shuf:
                    ls, lt = shuffle_pair(ls, lt)
            fs.write("\n\n".join(ls).encode("utf-8"))
            fs.write(ens)
            ft.write("\n\n".join(lt).encode("utf-8"))
            ft.write(ens)
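
A minimal sketch for this document-level variant (hypothetical file names): the inputs are expected to hold one sentence per line with documents separated by blank lines, matching the grouping logic above, and dict_insert_set, clean_liststr_lentok, iter_dict_sort, maxfreq_filter and shuffle_pair are assumed to come from the project's utility modules.

# Hypothetical call: sort blank-line-delimited parallel documents by size.
handle("train.doc.src", "train.doc.tgt", "train.doc.srt.src", "train.doc.srt.tgt")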