def bulid_model(opt, src_vocab_size, tgt_vocab_size):
    """Build the ``_netW`` / ``_netE_att`` networks and the ``nPairLoss`` criterion.

    When ``opt.train_from`` is a non-empty string the checkpoint at that path
    is loaded and used to restore the four networks, the learning rate and the
    last epoch. Otherwise the ``opt.save_model`` directory is created.

    Args:
        opt: Options object. The attributes read here are ``train_from``,
            ``save_model``, ``lr``, ``rnn_size``, ``dropout``, ``margin`` and
            ``cuda``.
        src_vocab_size: Vocabulary size passed to the source-side networks.
        tgt_vocab_size: Vocabulary size passed to the target-side networks.

    Returns:
        The tuple ``(tgt_netW, src_netW, src_netE_att, tgt_netE_att, critD,
        lr, last_epoch)``. ``lr`` and ``last_epoch`` come from the checkpoint
        when one is loaded; otherwise they are ``opt.lr`` and ``0``.
    """
    last_epoch = 0
    lr = opt.lr
    checkpoint = None
    resume = opt.train_from != ''

    if resume:
        print("=> loading checkpoint '{}'".format(opt.train_from))
        # When CUDA is not requested, remap every storage to the CPU so that a
        # checkpoint saved from a GPU run still loads on a CPU-only machine.
        map_location = None if opt.cuda else (lambda storage, loc: storage)
        checkpoint = torch.load(opt.train_from, map_location=map_location)
    else:
        save_path = opt.save_model
        print("save_path", save_path)
        try:
            os.makedirs(save_path)
        except OSError as err:
            # An already-existing directory is expected and stays silent; any
            # other failure is reported instead of being swallowed. It is not
            # re-raised so the best-effort behaviour is kept.
            if not os.path.isdir(save_path):
                print("warning: could not create '{}': {}".format(
                    save_path, err))

    print("rnn size", opt.rnn_size)
    src_netE_att = _netE_att(opt.rnn_size, src_vocab_size, opt.dropout)
    tgt_netE_att = _netE_att(opt.rnn_size,
                             tgt_vocab_size,
                             opt.dropout,
                             is_target=True)

    src_netW = _netW(src_vocab_size, opt.rnn_size, opt.dropout, name="src")
    tgt_netW = _netW(tgt_vocab_size, opt.rnn_size, opt.dropout, name="tgt")
    print("src_netW", src_vocab_size, src_netW.word_embed)
    print("tgt_netW", tgt_vocab_size, tgt_netW.word_embed)

    critD = model.nPairLoss(opt.rnn_size, opt.margin)

    if resume:  # restore the pre-trained weights and training state
        src_netW.load_state_dict(checkpoint['src_netW'])
        tgt_netW.load_state_dict(checkpoint['tgt_netW'])
        src_netE_att.load_state_dict(checkpoint['src_netE_att'])
        tgt_netE_att.load_state_dict(checkpoint['tgt_netE_att'])
        last_epoch = checkpoint['epoch']
        lr = checkpoint['lr']

    if opt.cuda:
        # One statement per module, in the same order as before.
        tgt_netW.cuda()
        src_netW.cuda()
        src_netE_att.cuda()
        tgt_netE_att.cuda()
        critD.cuda()

    return tgt_netW, src_netW, src_netE_att, tgt_netE_att, critD, lr, last_epoch
Пример #2
0
####################################################################################
# Build the Model
####################################################################################

# Sizes and the ``itow`` table are read from ``dataset_val``.
img_feat_size = 512
n_words = dataset_val.vocab_size
ques_length = dataset_val.ques_length
ans_length = dataset_val.ans_length + 1
his_length = ques_length + dataset_val.ans_length
itow = dataset_val.itow

# Networks are constructed in the order netE, netW, netD, then the criterion.
netE = _netE(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    opt.dropout,
    img_feat_size,
)
netW = model._netW(n_words, opt.ninp, opt.dropout)
netD = model._netD(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    n_words,
    opt.dropout,
)
critD = model.nPairLoss(opt.nhid, 2)

# Restore pre-trained weights when a model path was supplied.
# NOTE(review): ``checkpoint`` is not defined in this excerpt - presumably it
# is loaded earlier in the script; confirm.
if opt.model_path != '':
    netW.load_state_dict(checkpoint['netW'])
    netE.load_state_dict(checkpoint['netE'])
    netD.load_state_dict(checkpoint['netD'])
    print('Loading model Success!')

# Ship every module to CUDA when requested.
if opt.cuda:
    netW.cuda()
    netE.cuda()
    netD.cuda()
    critD.cuda()

n_neg = 100
####################################################################################
# Some Functions
####################################################################################
Пример #3
0
####################################################################################
# Build the Model
####################################################################################
# Number of negative samples comes from the options; sizes and the ``itow``
# table are read from ``dataset``.
n_neg = opt.negative_sample
vocab_size = dataset.vocab_size
ques_length = dataset.ques_length
ans_length = dataset.ans_length + 1
his_length = dataset.ans_length + dataset.ques_length
itow = dataset.itow
img_feat_size = 512

# Networks are constructed in the order netE, netW, netD, then the criterion.
netE = _netE(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    opt.dropout,
    img_feat_size,
)
netW = model._netW(vocab_size, opt.ninp, opt.dropout)
netD = model._netD(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    vocab_size,
    opt.dropout,
)
critD = model.nPairLoss(opt.ninp, opt.margin)

# Restore pre-trained weights when a model path was supplied.
# NOTE(review): ``checkpoint`` is not defined in this excerpt - presumably it
# is loaded earlier in the script; confirm.
if opt.model_path != '':
    netW.load_state_dict(checkpoint['netW'])
    netE.load_state_dict(checkpoint['netE'])
    netD.load_state_dict(checkpoint['netD'])

# Ship every module to CUDA when requested.
if opt.cuda:
    netW.cuda()
    netE.cuda()
    netD.cuda()
    critD.cuda()

####################################################################################
# training model
####################################################################################
def train(epoch):
    """Training entry point; only its first statement appears in this excerpt.

    The visible body calls ``netW.train()`` and does not use ``epoch``.
    NOTE(review): the rest of the function seems to be missing here - confirm
    against the full source before relying on this definition.
    """
    netW.train()
Пример #4
0
####################################################################################
# Build the Model
####################################################################################
# Number of negative samples comes from the options; sizes and the ``itow``
# table are read from ``dataset``.
n_neg = opt.negative_sample
vocab_size = dataset.vocab_size
ques_length = dataset.ques_length
ans_length = dataset.ans_length + 1
his_length = dataset.ans_length + dataset.ques_length
itow = dataset.itow
img_feat_size = 512

# Networks are constructed in the order netE, netW, netD, then the criterion.
netE = _netE(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    opt.dropout,
    img_feat_size,
)
netW = model._netW(vocab_size, opt.ninp, opt.dropout)
netD = model._netD(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    vocab_size,
    opt.dropout,
)
critD = model.nPairLoss(opt.ninp, opt.margin)

# Restore pre-trained weights when a model path was supplied.
# NOTE(review): ``checkpoint`` is not defined in this excerpt - presumably it
# is loaded earlier in the script; confirm.
if opt.model_path != '':
    netW.load_state_dict(checkpoint['netW'])
    netE.load_state_dict(checkpoint['netE'])
    netD.load_state_dict(checkpoint['netD'])

# Ship every module to CUDA when requested.
if opt.cuda:
    netW.cuda()
    netE.cuda()
    netD.cuda()
    critD.cuda()

####################################################################################
# training model
####################################################################################
def train(epoch):
    """Training entry point; only its first statement appears in this excerpt.

    The visible body calls ``netW.train()`` and does not use ``epoch``.
    NOTE(review): the rest of the function seems to be missing here - confirm
    against the full source before relying on this definition.
    """
    netW.train()
Пример #5
0
####################################################################################
# Build the Model
####################################################################################

# Sizes and the ``itow`` table are read from ``dataset_val``.
img_feat_size = 512
n_words = dataset_val.vocab_size
ques_length = dataset_val.ques_length
ans_length = dataset_val.ans_length + 1
his_length = ques_length + dataset_val.ans_length
itow = dataset_val.itow

# Networks are constructed in the order netE, netW, netD, then the criterion.
netE = _netE(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    opt.dropout,
    img_feat_size,
)
netW = model._netW(n_words, opt.ninp, opt.dropout)
netD = model._netD(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    n_words,
    opt.dropout,
)
critD = model.nPairLoss(opt.nhid, 2)

# Restore pre-trained weights when a model path was supplied.
# NOTE(review): ``checkpoint`` is not defined in this excerpt - presumably it
# is loaded earlier in the script; confirm.
if opt.model_path != '':
    netW.load_state_dict(checkpoint['netW'])
    netE.load_state_dict(checkpoint['netE'])
    netD.load_state_dict(checkpoint['netD'])
    print('Loading model Success!')

# Ship every module to CUDA when requested.
if opt.cuda:
    netW.cuda()
    netE.cuda()
    netD.cuda()
    critD.cuda()

n_neg = 100
####################################################################################
# Some Functions
####################################################################################
Пример #6
0
####################################################################################
# Build the Model
####################################################################################
# Number of negative samples comes from the options; sizes and the ``itow``
# table are read from ``dataset``.
n_neg = opt.negative_sample
vocab_size = dataset.vocab_size
ques_length = dataset.ques_length
ans_length = dataset.ans_length + 1
his_length = dataset.ans_length + dataset.ques_length
itow = dataset.itow
img_feat_size = 512

# Networks are constructed in the order netE, netW, netD, then the criterion.
netE = _netE(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    opt.dropout,
    img_feat_size,
)
netW = model._netW(vocab_size, opt.ninp, opt.dropout)
netD = model._netD(
    opt.model,
    opt.ninp,
    opt.nhid,
    opt.nlayers,
    vocab_size,
    opt.dropout,
)
critD = model.nPairLoss(
    opt.ninp,
    opt.margin,
    opt.alpha_norm,
    opt.sigma,
    opt.alphaC,
    opt.alphaE,
    opt.alphaN,
    opt.debug,
    opt.log_interval,
    opt.contra_thresh,
)

# Restore pre-trained weights when a model path was supplied.
# NOTE(review): ``checkpoint`` is not defined in this excerpt - presumably it
# is loaded earlier in the script; confirm.
if opt.model_path != '':
    netW.load_state_dict(checkpoint['netW'])
    netE.load_state_dict(checkpoint['netE'])
    netD.load_state_dict(checkpoint['netD'])

# Ship every module to CUDA when requested.
if opt.cuda:
    netW.cuda()
    netE.cuda()
    netD.cuda()
    critD.cuda()

####################################################################################
# training model
####################################################################################
def train(epoch):
    """Training entry point; only its first statement appears in this excerpt.

    The visible body calls ``netW.train()`` and does not use ``epoch``.
    NOTE(review): the rest of the function seems to be missing here - confirm
    against the full source before relying on this definition.
    """
    netW.train()
def bulid_model_D(opt, src_vocab_size, tgt_vocab_size):
    """Build the ``_netW`` / ``_netE_att`` networks plus the loss and sampling modules.

    The ``opt.save_model`` directory is always created. When
    ``opt.model_path_D`` is a non-empty string the checkpoint at that path is
    loaded and used to restore the four networks and the learning rate.

    Args:
        opt: Options object. The attributes read here are ``save_model``,
            ``model_path_D``, ``LM_lr``, ``rnn_size``, ``dropout``,
            ``margin`` and ``cuda``.
        src_vocab_size: Vocabulary size passed to the source-side networks.
        tgt_vocab_size: Vocabulary size passed to the target-side networks.

    Returns:
        The tuple ``(tgt_netW, src_netW, src_netE_att, tgt_netE_att, critD,
        lr, sampler, critG, critLM, BLEU_score)``. ``lr`` comes from the
        checkpoint when one is loaded; otherwise it is ``opt.LM_lr``.
    """
    save_path = opt.save_model
    print("save_path", save_path)
    try:
        os.makedirs(save_path)
    except OSError as err:
        # An already-existing directory is expected and stays silent; any
        # other failure is reported instead of being swallowed. It is not
        # re-raised so the best-effort behaviour is kept.
        if not os.path.isdir(save_path):
            print("warning: could not create '{}': {}".format(save_path, err))

    lr = opt.LM_lr
    checkpoint_D = None
    resume = opt.model_path_D != ''
    if resume:
        print("=> loading checkpoint '{}'".format(opt.model_path_D))
        # When CUDA is not requested, remap every storage to the CPU so that a
        # checkpoint saved from a GPU run still loads on a CPU-only machine.
        map_location = None if opt.cuda else (lambda storage, loc: storage)
        checkpoint_D = torch.load(opt.model_path_D, map_location=map_location)

    print("rnn size", opt.rnn_size)
    src_netE_att = _netE_att(opt.rnn_size, src_vocab_size, opt.dropout)
    src_netW = _netW(src_vocab_size, opt.rnn_size, opt.dropout, name="src", cuda=opt.cuda)
    print("src_netW", src_vocab_size, src_netW.word_embed)
    tgt_netW = _netW(tgt_vocab_size, opt.rnn_size, opt.dropout, name="tgt", cuda=opt.cuda)
    print("tgt_netW", tgt_vocab_size, tgt_netW.word_embed)
    tgt_netE_att = _netE_att(opt.rnn_size, tgt_vocab_size, opt.dropout, is_target=True)

    critD = model.nPairLoss(opt.rnn_size, opt.margin)

    if resume:  # restore the pre-trained weights and learning rate
        src_netW.load_state_dict(checkpoint_D['src_netW'])
        tgt_netW.load_state_dict(checkpoint_D['tgt_netW'])
        src_netE_att.load_state_dict(checkpoint_D['src_netE_att'])
        tgt_netE_att.load_state_dict(checkpoint_D['tgt_netE_att'])
        lr = checkpoint_D['lr']

    if opt.cuda:
        tgt_netW.cuda()
        src_netW.cuda()
        src_netE_att.cuda()
        tgt_netE_att.cuda()
        critD.cuda()

    print('init Generative model...')

    sampler = model.gumbel_sampler()
    critG = model.G_loss(opt.rnn_size)
    critLM = model.LMCriterion()
    BLEU_score = model.BLEU_score()

    if opt.cuda:  # ship to cuda, if has GPU
        sampler.cuda()
        critG.cuda()
        critLM.cuda()
        BLEU_score.cuda()

    print("load netD successfully")
    return tgt_netW, src_netW, src_netE_att, tgt_netE_att, critD, lr, sampler, critG, critLM, BLEU_score