def train_graph():
    """Drive the VTranse model and the text/visual ``Graph`` branches (WIP debug loop).

    Builds the VTranse graph and two ``Graph`` instances that share the
    ``feature`` placeholder. Restores every trainable variable whose name
    contains ``'vgg_16'`` or ``'RD'`` from ``input_hyper['model_path']``.
    Then, for each epoch and each training roidb entry that has at least one
    ground-truth relation, it:

    * calls ``vnet.train_predicate``;
    * prints the shape of its third return value and of
      ``input_hyper['rel_emb']``;
    * evaluates the text branch on ``input_hyper['rel_emb']`` and prints the
      result.

    NOTE(review): ``optimizer`` is constructed but no training op is built
    from it here, and ``train_predicate`` is called with ``None`` as its
    third argument; this looks like work-in-progress, so confirm the intent.
    """
    input_hyper = init_hyper()
    input_pls = {
        'feature': tf.placeholder(
            dtype=tf.float32, shape=[None, input_hyper['hidden_dim']]),
    }

    vnet = VTranse()
    vnet.create_graph(input_hyper['N_each_batch'], input_hyper['index_sp'],
                      input_hyper['index_cls'], input_hyper['N_cls'],
                      input_hyper['N_rela'])

    # Both branches are built with identical arguments. graph_visual is not
    # read below; it is kept because constructing it presumably adds
    # ops/variables to the default graph (TODO confirm whether it is needed).
    graph_visual = Graph(input_hyper['layer_num'], input_hyper['hidden_dim'],
                         input_hyper['num_cls'], input_pls,
                         input_hyper['knowledge'])
    graph_text = Graph(input_hyper['layer_num'], input_hyper['hidden_dim'],
                       input_hyper['num_cls'], input_pls,
                       input_hyper['knowledge'])
    text_layer_out = graph_text.get_layer_out()

    # TODO: no minimize() op is created from this optimizer yet.
    optimizer = tf.train.AdamOptimizer(learning_rate=input_hyper['lr_rate'])

    # Only 'vgg_16' / 'RD' variables are restored from the checkpoint; all
    # other variables keep the values from the global initializer below.
    restore_var = [
        var for var in tf.trainable_variables()
        if 'vgg_16' in var.name or 'RD' in var.name
    ]
    saver_res = tf.train.Saver(restore_var)

    # Allocate GPU memory on demand rather than reserving it all up front.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Initialise everything, then overwrite the restorable subset.
        sess.run(tf.global_variables_initializer())
        saver_res.restore(sess, input_hyper['model_path'])

        roidb_read = read_roidb(input_hyper['roidb_path'])
        train_roidb = roidb_read['train_roidb']

        # The text branch is fed the same embeddings at every step, so the
        # feed dict and the embedding shape are computed once.
        feed_dict = {input_pls['feature']: input_hyper['rel_emb']}
        rel_emb_shape = np.shape(input_hyper['rel_emb'])

        for _epoch in range(input_hyper['num_epoch']):
            for roidb_use in train_roidb:
                # Skip entries without ground-truth relations.
                if len(roidb_use['rela_gt']) == 0:
                    continue
                _rd_loss, _acc, diff = vnet.train_predicate(
                    sess, roidb_use, None)
                diff = np.array(diff)
                print(diff.shape)
                print(rel_emb_shape)
                # TODO: build per-relation visual features from `diff` and
                # feed them to graph_visual (previously sketched as a loop
                # over diff.shape[0] appending to a `vf` list).
                # NOTE(review): the text-branch evaluation below was placed
                # in the inner loop when reconstructing the collapsed source;
                # confirm it should not run once per epoch instead.
                text_out = sess.run(text_layer_out, feed_dict=feed_dict)
                print(text_out)
tf_config.gpu_options.allow_growth = True with tf.Session(config=tf_config) as sess: init = tf.global_variables_initializer() sess.run(init) saver_res.restore(sess, res_path) t = 0.0 rd_loss = 0.0 acc = 0.0 for r in range(N_round): for roidb_id in range(N_train): roidb_use = train_roidb[roidb_id] if len(roidb_use['rela_gt']) == 0: continue rd_loss_temp, acc_temp = vnet.train_predicate( sess, roidb_use, RD_train) rd_loss = rd_loss + rd_loss_temp acc = acc + acc_temp t = t + 1.0 if t % N_show == 0: print("t: {0}, rd_loss: {1}, acc: {2}".format( t, rd_loss / N_show, acc / N_show)) rd_loss = 0.0 acc = 0.0 if t % N_save == 0: save_path = cfg.DIR + 'vtranse/pred_para/vg_vgg/vg_vgg' + format( int(t / N_save), '04') + '.ckpt' print("saving model to {0}".format(save_path)) saver.save(sess, save_path) rd_loss_val = 0.0 acc_val = 0.0