예제 #1
0
            # Network hyper-parameters, passed positionally to `arch` below.
            # NOTE(review): n_modalities is 5, yet every per-modality list
            # below has 10 entries -- confirm which count `arch` expects.
            learning_rate, batch_size = 0.0001, 1000
            n_modalities = 5
            size_modalities = [4] * 4 + [1] * 4 + [4] * 2
            numModels = [10] * 4 + [2] * 4 + [10] * 2
            numFactors = list(numModels)  # same values, independent list
            used_modalities = [1] * 10  # all ten entries set to 1
            corruption_level = softmaxnoise = 0.0
            numClass, numParam = 1, 100
            vanilla = True

            network = arch(
                n_modalities, size_modalities, numModels, numFactors,
                numClass, numParam, used_modalities, batch_size,
                learning_rate, vanilla, corruption_level, softmaxnoise,
            )

            # Load the trained weights into `sess` (defined outside this
            # fragment) from the complete-model checkpoint.
            new_saver = tf.train.Saver()
            new_saver.restore(sess, "./models/droniou_complete.ckpt")
            print("Model restored.")

            # Run the requested test; a test_id outside 1-4 runs nothing.
            if test_id == 1:
                test_1()
            if test_id == 2:
                test_2()
            if test_id == 3:
                test_3()
            if test_id == 4:
                test_4()
예제 #2
0
파일: test_droniou.py 프로젝트: Dekelv/VAE
with tf.Graph().as_default() as g:
	# Indentation in this section is tabs-only. The previous version mixed
	# tab-indented and space-indented lines inside the same suite, which
	# Python 3 rejects with TabError.
	# NOTE(review): this first session is never referenced; the network is
	# merely built while it is open. Confirm arch() does not rely on it
	# before removing it.
	with tf.Session() as sess:
		# Network parameters (each list below has n_modalities entries).
		learning_rate = 0.0001
		batch_size = 1000
		n_modalities = 5
		size_modalities = [8, 8, 2, 2, 8]
		numModels = [20, 20, 5, 5, 20]
		numFactors = [20, 20, 5, 5, 20]
		# NOTE(review): every entry is 0 here -- confirm this is the
		# intended modality mask for this test.
		used_modalities = np.zeros(n_modalities)
		numClass = 1
		numParam = 100
		network = arch(
			n_modalities, size_modalities, numModels, numFactors,
			numClass, numParam, used_modalities, batch_size, learning_rate)

	# New session on the same default graph `g`; the weights are restored
	# into this session and used by the tests below.
	with tf.Session() as sess:
		new_saver = tf.train.Saver()
		new_saver.restore(sess, "./models_vanilla/droniou_complete.ckpt")
		print("Model restored.")

		# Test 1: complete data
		print('Test 1')
		sample_init = 100
		# 28 == sum(size_modalities); presumably the concatenated modality
		# columns of X_augm_test -- verify if the configuration changes.
		x_sample = X_augm_test[sample_init:sample_init + batch_size, :28]
		x_reconstruct = network.reconstruct(sess, x_sample)
		scipy.io.savemat(
			"results/mvae_final_test1.mat",
			{"x_reconstruct": x_reconstruct, "x_sample": x_sample})