Code Example #1
import h5py
import torch

from utils.fmt.base import ldvocab, reverse_dict


def handle(h5f, vcbsf, vcbtf, rsfs, rsft):

    # restore the plain-text source/target sides of a binarized HDF5 test set
    td = h5py.File(h5f, "r")

    ntest = td["ndata"][:].item()
    nwordi = td["nwordi"][:].tolist()[0]
    # reverse both vocabularies so they map index -> token
    vcbs, nwords = ldvocab(vcbsf)
    vcbs = reverse_dict(vcbs)
    vcbt, nwordt = ldvocab(vcbtf)
    vcbt = reverse_dict(vcbt)
    src_grp, tgt_grp = td["src"], td["tgt"]

    ens = "\n".encode("utf-8")

    with open(rsfs, "wb") as fs, open(rsft, "wb") as ft:
        for i in range(ntest):
            curid = str(i)
            # decode the source side of batch i
            curd = torch.from_numpy(src_grp[curid][:]).tolist()
            md = [" ".join(vcbs.get(wid) for wid in seq) for seq in curd]
            fs.write("\n".join(md).encode("utf-8"))
            fs.write(ens)
            # decode the target side of batch i
            curd = torch.from_numpy(tgt_grp[curid][:]).tolist()
            md = [" ".join(vcbt.get(wid) for wid in seq) for seq in curd]
            ft.write("\n".join(md).encode("utf-8"))
            ft.write(ens)

    td.close()
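
A minimal invocation sketch for the function above; the file names are placeholders, not paths taken from the project:

# Usage sketch: dump a binarized test set back to plain text
# (all file names are placeholders).
handle("test.h5", "src.vcb", "tgt.vcb", "test.src.txt", "test.tgt.txt")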
Code Example #2
import torch

from utils.fmt.base import ldvocab, reverse_dict

# NMT, load_model_cpu, save_model, cnfg, cache_len_default and h5zipargs are
# project-level names assumed to be in scope, as in the other examples here.


def handle(common, src, tgt, srcm, rsm, minfreq=False, vsize=False):

    # the merged (common) vocabulary the updated model will use
    vcbc, nwordf = ldvocab(common,
                           minf=minfreq,
                           omit_vsize=vsize,
                           vanilla=False)

    # map every index of the model's source vocabulary to its index in the
    # common vocabulary (0 for words missing from it); None means the
    # source embedding can be reused unchanged
    if src == common:
        src_indices = None
    else:
        vcbw, nword = ldvocab(src,
                              minf=minfreq,
                              omit_vsize=vsize,
                              vanilla=False)
        vcbw = reverse_dict(vcbw)
        src_indices = torch.tensor(
            [vcbc.get(vcbw[i], 0) for i in range(nword)], dtype=torch.long)
    # same mapping for the target side
    if tgt == common:
        tgt_indices = None
    else:
        vcbw, nword = ldvocab(tgt,
                              minf=minfreq,
                              omit_vsize=vsize,
                              vanilla=False)
        vcbw = reverse_dict(vcbw)
        tgt_indices = torch.tensor(
            [vcbc.get(vcbw[i], 0) for i in range(nword)], dtype=torch.long)

    # rebuild the model with the common vocabulary size, load the trained
    # weights, remap its embeddings, then save the result
    mymodel = NMT(cnfg.isize, nwordf, nwordf, cnfg.nlayer, cnfg.ff_hsize,
                  cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
                  cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
                  cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
    mymodel = load_model_cpu(srcm, mymodel)
    mymodel.update_vocab(src_indices=src_indices, tgt_indices=tgt_indices)
    save_model(mymodel, rsm, sub_module=False, h5args=h5zipargs)
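
A hedged invocation sketch for this vocabulary-remapping step; every file name below is a placeholder:

# Usage sketch: remap a trained checkpoint onto a merged vocabulary
# (all file names are placeholders).
handle("common.vcb", "src.vcb", "tgt.vcb", "model.h5", "model.common.h5")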
Code Example #3
File: ext_emb.py  Project: masonreznov/transformer
import torch

from utils.fmt.base import ldvocab, reverse_dict

# load_emb_txt and h5save are project-level helpers assumed to be in scope.


def handle(vcbf, embf, rsf):

    # read the vocabulary and the pretrained text-format embeddings
    vcb, nwd = ldvocab(vcbf)
    emb = load_emb_txt(vcb, embf)
    # fall back to the <unk> vector, or zeros of the same width, for
    # out-of-vocabulary words
    unkemb = emb.get("<unk>", torch.zeros(next(iter(emb.values())).size(0)))
    vcb = reverse_dict(vcb)
    # stack one vector per vocabulary index and save the matrix to HDF5
    rs = [emb.get(vcb[i], unkemb) for i in range(nwd)]
    h5save(torch.stack(rs, 0), rsf)
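
A short usage sketch; the file names are placeholders:

# Usage sketch: build an HDF5 embedding matrix for src.vcb from a
# word2vec/GloVe-style text file (file names are placeholders).
handle("src.vcb", "embeddings.txt", "src.emb.h5")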
Code Example #4
	def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mulgpu=True, bsize=64, maxpad=16, maxpart=4, maxtoken=1536, minfreq=False, vsize=False):

		vcbi, nwordi = ldvocab(fvocab_i, minfreq, vsize)
		vcbt, nwordt = ldvocab(fvocab_t, minfreq, vsize)
		# keep the source vocabulary as token -> index for encoding input,
		# and reverse the target one to index -> token for decoding output
		self.vcbi, self.vcbt = vcbi, reverse_dict(vcbt)

		if expand_for_mulgpu:
			self.bsize = bsize * minbsize
			self.maxtoken = maxtoken * minbsize
		else:
			self.bsize = bsize
			self.maxtoken = maxtoken
		self.maxpad = maxpad
		self.maxpart = maxpart
		self.minbsize = minbsize

		# a list/tuple of checkpoint files is loaded as an ensemble of models
		if isinstance(modelfs, (list, tuple)):
			models = []
			for modelf in modelfs:
				tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)

				tmp = load_model_cpu(modelf, tmp)
				tmp.apply(load_fixing)

				models.append(tmp)
			model = Ensemble(models)

		else:
			model = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)

			model = load_model_cpu(modelfs, model)
			model.apply(load_fixing)

		# inference only: switch off dropout
		model.eval()

		self.use_cuda, self.cuda_device, cuda_devices, self.multi_gpu = parse_cuda_decode(cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding)

		if self.use_cuda:
			model.to(self.cuda_device)
			if self.multi_gpu:
				model = DataParallelMT(model, device_ids=cuda_devices, output_device=self.cuda_device.index, host_replicate=True, gather_output=False)
		self.use_amp = cnfg.use_amp and self.use_cuda

		self.beam_size = cnfg.beam_size

		self.length_penalty = cnfg.length_penalty
		self.net = model
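
A hedged construction sketch; the wrapper's class name (TranslatorCore) and the file names are placeholders used only for illustration:

# Construction sketch (class and file names are placeholders):
translator = TranslatorCore("model.h5", "src.vcb", "tgt.vcb", cnfg)
# passing several checkpoints decodes with an ensemble instead:
translator = TranslatorCore(["avg.h5", "best.h5"], "src.vcb", "tgt.vcb", cnfg)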
Code Example #5
import sys

import h5py

from utils.fmt.base import ldvocab, reverse_dict, init_vocab, sos_id, eos_id
from utils.fmt.base4torch import parse_cuda_decode

# NMT, cnfg, load_model_cpu and cache_len_default are project-level names
# assumed to be imported as in the other examples here.


def load_fixing(module):

    if hasattr(module, "fix_load"):
        module.fix_load()


td = h5py.File(sys.argv[1], "r")

ntest = td["ndata"][:].item()
nwordi = td["nword"][:].tolist()[0]
vcbt, nwordt = ldvocab(sys.argv[3])
vcbt = reverse_dict(vcbt)

mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize,
              cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
              cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
              cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd)

mymodel = load_model_cpu(sys.argv[2], mymodel)
mymodel.apply(load_fixing)

mymodel.eval()

# pull out the encoder and the decoder's trans / classifier sub-modules
enc, trans, classifier = mymodel.enc, mymodel.dec.trans, mymodel.dec.classifier

use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda_decode(
    cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding)
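
The snippet ends before any decoding happens; following the device-handling pattern of Code Example #4, the next step would usually move the model onto the selected device. A minimal sketch of that step:

# Sketch of the step that typically follows (mirrors Code Example #4):
if use_cuda:
    mymodel.to(cuda_device)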