Beispiel #1
0
def train_graph():
    """Build the VTranse and Graph models, restore weights, and iterate the data.

    Hyper-parameters come from ``init_hyper()``. A VTranse network and two
    ``Graph`` instances (visual and text) are constructed over one shared
    ``'feature'`` placeholder. Trainable variables whose name contains
    ``'vgg_16'`` or ``'RD'`` are restored from ``input_hyper['model_path']``.

    For every epoch, each training roidb entry with at least one ground-truth
    relation is passed to ``vnet.train_predicate``; the text graph's layer
    output is then evaluated on ``input_hyper['rel_emb']`` and printed.

    Returns:
        None. Results are only printed.
    """
    input_hyper = init_hyper()
    input_pls = {
        'feature':
        tf.placeholder(dtype=tf.float32,
                       shape=[None, input_hyper['hidden_dim']])
    }

    vnet = VTranse()
    vnet.create_graph(input_hyper['N_each_batch'], input_hyper['index_sp'],
                      input_hyper['index_cls'], input_hyper['N_cls'],
                      input_hyper['N_rela'])

    # NOTE(review): graph_visual is never read below, but it is still
    # constructed because Graph() may register variables that the
    # initializer relies on -- confirm before removing.
    graph_visual = Graph(input_hyper['layer_num'], input_hyper['hidden_dim'],
                         input_hyper['num_cls'], input_pls,
                         input_hyper['knowledge'])

    graph_text = Graph(input_hyper['layer_num'], input_hyper['hidden_dim'],
                       input_hyper['num_cls'], input_pls,
                       input_hyper['knowledge'])

    text_layer_out = graph_text.get_layer_out()
    # NOTE(review): the optimizer is created but no loss is minimized with it
    # in this function yet.
    optimizer = tf.train.AdamOptimizer(learning_rate=input_hyper['lr_rate'])
    train_var = tf.trainable_variables()
    # Only the variables matching these name fragments are restored from the
    # checkpoint; everything else keeps its freshly initialized value.
    restore_var = [
        var for var in train_var if 'vgg_16' in var.name or 'RD' in var.name
    ]
    saver_res = tf.train.Saver(restore_var)

    # The text graph is always fed the same embeddings, so build the feed
    # dict once instead of once per sample.
    feed_dict = {input_pls['feature']: input_hyper['rel_emb']}

    with tf.Session() as sess:
        # Initialize every variable first, then overwrite the restorable
        # subset with the checkpoint values.
        sess.run(tf.global_variables_initializer())
        saver_res.restore(sess, input_hyper['model_path'])
        roidb_read = read_roidb(input_hyper['roidb_path'])
        train_roidb = roidb_read['train_roidb']

        for epoch in range(input_hyper['num_epoch']):
            for roidb_use in train_roidb:
                # Skip samples without ground-truth relations. len() is used
                # rather than truthiness because 'rela_gt' may be an array.
                if len(roidb_use['rela_gt']) == 0:
                    continue
                rd_loss_temp, acc_temp, diff = vnet.train_predicate(
                    sess, roidb_use, None)
                diff = np.array(diff)
                print(diff.shape)
                print(np.array(input_hyper['rel_emb']).shape)
                # TODO: per-batch visual features derived from `diff` are not
                # assembled yet; only the text graph is evaluated.
                text_out = sess.run(text_layer_out, feed_dict=feed_dict)
                print(text_out)
Beispiel #2
0
# Add the V-Trans-E source directory to the module search path; the imports
# below presumably resolve from there (verify against the repository layout).
sys.path.append("models/ImgVRD/V-Trans-E")

from model.config import cfg
from model.ass_fun import *
from net.vtranse_vgg import VTranse

# Sizes read from the shared config and passed to create_graph below.
N_cls = cfg.VG_NUM_CLASS
N_rela = cfg.VG_NUM_RELA
N_each_batch = cfg.VG_BATCH_NUM_RELA

# Flags forwarded to create_graph; both are disabled for this run.
# NOTE(review): their exact meaning is defined inside VTranse -- confirm there.
index_sp = False
index_cls = False

# Build the network graph before creating the Saver further down.
vnet = VTranse()
vnet.create_graph(N_each_batch, index_sp, index_cls, N_cls, N_rela)

# File locations, all relative to cfg.DIR (plain string concatenation, so
# cfg.DIR is expected to end with a path separator).
roidb_path = cfg.DIR + 'vtranse/input/vg_rela_roidb.npz'
model_path = cfg.DIR + 'vtranse/pred_para/vg_vgg_rela/vg_vgg0001.ckpt'
save_path = cfg.DIR + 'vtranse/pred_res/vg_rela_roidb.npz'

# Load the roidb file and split it into its train and test parts.
# read_roidb is presumably provided by the star import from model.ass_fun.
roidb_read = read_roidb(roidb_path)
train_roidb = roidb_read['train_roidb']
test_roidb = roidb_read['test_roidb']
N_train = len(train_roidb)
N_test = len(test_roidb)
print('data loaded')

# Created after create_graph so that the graph's variables already exist.
saver = tf.train.Saver()

with tf.Session() as sess: