Example #1
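# Assumes the toolkit's usual imports are in scope (h5File, NMT, NMTBase,
# init_model_params, load_model_cpu, save_model, h5zipargs,
# cache_len_default); the exact module paths depend on the repository layout.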
def handle(cnfg, srcmtf, decf, rsf):

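    # read the source/target vocabulary sizes from the processed dev data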
    with h5File(cnfg.dev_data, "r") as tdf:
        nwordi, nwordt = tdf["nword"][:].tolist()

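    # build the full model; the trailing cnfg.num_layer_fwd argument is what
    # distinguishes it from the NMTBase constructed below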
    mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize,
                  cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
                  cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
                  cnfg.bindDecoderEmb, cnfg.forbidden_indexes,
                  cnfg.num_layer_fwd)
    mymodel = init_model_params(mymodel)
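    # temporary base model used only to receive the pre-trained checkpoint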
    _tmpm = NMTBase(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize,
                    cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
                    cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
                    cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
    _tmpm = init_model_params(_tmpm)
    _tmpm = load_model_cpu(srcmtf, _tmpm)
    mymodel.load_base(_tmpm)
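    # load the separately trained decoder weights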
    mymodel.dec = load_model_cpu(decf, mymodel.dec)
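    # loading replaces parameter tensors, so re-establish the embedding ties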
    if cnfg.share_emb:
        mymodel.dec.wemb.weight = mymodel.enc.wemb.weight
    if cnfg.bindDecoderEmb:
        mymodel.dec.classifier.weight = mymodel.dec.wemb.weight
    _tmpm = None

    save_model(mymodel, rsf, sub_module=False, h5args=h5zipargs)
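
# Hypothetical driver for handle() above: merge a pre-trained base checkpoint
# with a separately trained decoder and save the combined model. The file
# names and the cnfg import are placeholders for this sketch.
if __name__ == "__main__":
    import cnfg

    handle(cnfg, "base.h5", "dec.h5", "combined.h5")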
Example #2
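# This snippet assumes the data files are already open (cf. Example #3):
#   td = h5py.File(cnfg.train_data, "r")
#   vd = h5py.File(cnfg.dev_data, "r")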
ntrain = td["ndata"][:].item()
nvalid = vd["ndata"][:].item()
nword = td["nword"][:].tolist()
nwordi, nwordt = nword[0], nword[-1]

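# keys of the training batches stored in the HDF5 file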
tl = [str(i) for i in range(ntrain)]

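# log the initialization seed for reproducibility, then build the model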
logger.info("Design models with seed: %d" % torch.initial_seed())
mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize,
              cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
              cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
              cnfg.bindDecoderEmb, cnfg.forbidden_indexes)

fine_tune_m = cnfg.fine_tune_m

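# fresh parameter initialization, optionally overwritten below by a
# pre-trained checkpoint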
mymodel = init_model_params(mymodel)
mymodel.apply(init_fixing)
if fine_tune_m is not None:
    logger.info("Load pre-trained model from: " + fine_tune_m)
    mymodel = load_model_cpu(fine_tune_m, mymodel)
    mymodel.apply(load_fixing)

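# label-smoothed cross entropy; padding positions are ignored and forbidden
# indices receive no smoothing mass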
lossf = LabelSmoothingLoss(nwordt,
                           cnfg.label_smoothing,
                           ignore_index=pad_id,
                           reduction='sum',
                           forbidden_index=cnfg.forbidden_indexes)

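# optionally initialize the encoder embedding from a pre-trained table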
if cnfg.src_emb is not None:
    logger.info("Load source embedding from: " + cnfg.src_emb)
    # the remaining load_emb arguments (embedding scaling / freezing options)
    # are assumed from the toolkit's training script; the original snippet was
    # cut off mid-call
    load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi,
             cnfg.scale_down_emb, cnfg.freeze_srcemb)
Example #3
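# keep the processed training and validation sets open for the whole run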
td = h5py.File(cnfg.train_data, "r")
vd = h5py.File(cnfg.dev_data, "r")

ntrain = td["ndata"][:].item()
nvalid = vd["ndata"][:].item()
nword = td["nword"][:].tolist()
nwordi, nwordt = nword[0], nword[-1]

tl = [str(i) for i in range(ntrain)]

logger.info("Design models with seed: %d" % torch.initial_seed())
mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize,
              cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
              cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
              cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd)

fine_tune_m = cnfg.fine_tune_m

mymodel = init_model_params(mymodel)
mymodel.apply(init_fixing)
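# when fine-tuning, the checkpoint is loaded into a separate frozen base model
# that is then attached to the full model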
if fine_tune_m is not None:
	logger.info("Load pre-trained model from: " + fine_tune_m)
	_tmpm = NMTBase(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize,
			cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead,
			cache_len_default, cnfg.attn_hsize, cnfg.norm_output,
			cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
	_tmpm = init_model_params(_tmpm)
	_tmpm.apply(init_fixing)
	_tmpm = load_model_cpu(fine_tune_m, _tmpm)
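	# freeze the pre-trained parameters before attaching them via load_base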
	freeze_module(_tmpm)
	mymodel.load_base(_tmpm)
	_tmpm = None

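# probing ablation: disable self- or cross-attention in the last decoder layer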
if cnfg.probe_remove_self:
	mymodel.dec.nets[-1].perform_self_attn = False
elif cnfg.probe_remove_cross:
	mymodel.dec.nets[-1].perform_cross_attn = False
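
# A minimal sketch (an assumption, not the toolkit's implementation) of how a
# decoder layer could honor the perform_self_attn / perform_cross_attn switches
# toggled above: a disabled sub-layer is skipped, leaving the residual stream
# unchanged. ProbeDecoderLayer and its arguments are hypothetical.
from torch import nn

class ProbeDecoderLayer(nn.Module):

	def __init__(self, isize, nhead):
		super().__init__()
		self.self_attn = nn.MultiheadAttention(isize, nhead, batch_first=True)
		self.cross_attn = nn.MultiheadAttention(isize, nhead, batch_first=True)
		self.perform_self_attn = True
		self.perform_cross_attn = True

	def forward(self, x, enc_out):
		if self.perform_self_attn:
			x = x + self.self_attn(x, x, x, need_weights=False)[0]
		if self.perform_cross_attn:
			x = x + self.cross_attn(x, enc_out, enc_out, need_weights=False)[0]
		return x

# disabling a sub-layer then works exactly like the probe flags above, e.g.:
#   layer = ProbeDecoderLayer(512, 8)
#   layer.perform_self_attn = False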