コード例 #1
0
# Create a single interactive session and hand it the GPU configuration
# directly. Building a separately configured tf.Session and then rebinding
# `sess` to a plain InteractiveSession would leave that first session open
# and the gpu_options unused.
sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

# Build the TensorFlow graph for the cascade stage selected by net_name.
# In the '12net' and '24net' branches the stage being trained gets
# input_config_noscale(), infer(), objective() and
# train(LearningRate, threshold); any preceding stage is only configured and
# infer()'d. In those branches network_list holds every network built.
# NOTE(review): there is no else branch, so an unrecognised net_name leaves
# network_list unset here — confirm net_name is validated earlier.
if net_name == '12net' :
	# Single stage: only the 12-net is built and trained.
	N12 = cvpr_network.cvpr_12net()
	N12.input_config_noscale()
	N12.infer()
	N12.objective()
	N12.train(LearningRate, threshold)
	network_list = [N12]

elif net_name == '24net' :
	# The 12-net is configured with input_config_div2() (presumably the
	# 24-net input scaled down by 2 — TODO confirm) and only infer()'d;
	# the 24-net is the stage trained here.
	N12 = cvpr_network.cvpr_12net()
	N12.input_config_div2()
	N12.infer()
	N24 = cvpr_network.cvpr_24net()
	N24.input_config_noscale()
	N24.infer()
	N24.objective()
	N24.train(LearningRate, threshold)
	network_list = [N12, N24]

elif net_name == '48net' :
	# NOTE(review): compared with the branches above, no input_config*(),
	# objective() or train() call on N48 and no network_list assignment is
	# visible here — this snippet looks truncated; confirm against the
	# full source.
	N12 = cvpr_network.cvpr_12net()
	N12.input_config_div4()
	N12.infer()
	N24 = cvpr_network.cvpr_24net()
	# NOTE(review): by analogy with input_config_div4() for N12 (and
	# input_config_div2() in the '24net' branch), one might expect
	# input_config_div2() here; confirm what input_config() scales to.
	N24.input_config()
	N24.infer()
	N48 = cvpr_network.cvpr_48net()
	N48.infer()
コード例 #2
0
# Parameters restored for each network. Each one is stored in its checkpoint
# under '<prefix>_<attribute name lower-cased>', e.g. N12.W_conv1 is saved
# as 'net12_w_conv1'. The tuple order fixes the key order of the Saver map.
_CKPT_PARAM_ATTRS = ('W_conv1', 'b_conv1', 'W_fc1', 'b_fc1', 'W_fc2', 'b_fc2')


def _restore_net_params(session, net, prefix, ckpt_file):
	"""Restore ``net``'s conv/fc weights and biases from ``ckpt_file``.

	Args:
		session: TensorFlow session to restore the variables into.
		net: Network object exposing every attribute in ``_CKPT_PARAM_ATTRS``.
		prefix: Checkpoint variable-name prefix, e.g. ``'net12'`` or ``'cal24'``.
		ckpt_file: Checkpoint path passed to ``Saver.restore``.

	Returns:
		The ``tf.train.Saver`` that performed the restore.
	"""
	# Deriving the names from one list keeps all four savers consistent and
	# avoids hand-typing 24 checkpoint keys.
	var_map = {prefix + '_' + attr.lower(): getattr(net, attr) for attr in _CKPT_PARAM_ATTRS}
	saver = tf.train.Saver(var_map)
	saver.restore(session, ckpt_file)
	return saver


# 12-scale detection net and its calibration net, both configured with
# input_config_noscale() and built via infer().
N12 = cvpr_network.cvpr_12net()
N12.input_config_noscale()
N12.infer()

CAL12 = cvpr_network.cvpr_calib_12net()
CAL12.input_config_noscale()
CAL12.infer()

saver1 = _restore_net_params(sess, N12, 'net12', ckpt_file1)
saver2 = _restore_net_params(sess, CAL12, 'cal12', ckpt_file2)

# The '48net' stage also needs the 24-scale net and its calibration net,
# restored from their own checkpoints.
if net_name == '48net':
	N24 = cvpr_network.cvpr_24net()
	N24.input_config_noscale()
	N24.infer()

	CAL24 = cvpr_network.cvpr_calib_24net()
	CAL24.input_config_noscale()
	CAL24.infer()

	saver3 = _restore_net_params(sess, N24, 'net24', ckpt_file3)
	saver4 = _restore_net_params(sess, CAL24, 'cal24', ckpt_file4)
if net_name == '24net' :
	network_list = [N12, CAL12]
elif net_name == '48net' :