Esempio n. 1
0
 def test(self, backup_path='backup/cifar10-v16/', epoch=500, image_str=''):
     """Run a restored checkpoint on a single image file.

     Restores ``model_<epoch>.ckpt`` from ``backup_path`` into a fresh
     session stored on ``self.sess``. The image at ``image_str`` is then
     preprocessed: it is resized to 32x32 and passed through
     ``Corpus.data_augmentation`` with a (24, 24, 3) crop and whitening.
     Finally ``self.accuracy`` is evaluated once for each candidate label
     0..9. The session is always closed before returning, including on error.

     Args:
         backup_path: directory containing the checkpoint files.
         epoch: epoch number used to build the checkpoint file name.
         image_str: path of the image file to evaluate.

     Returns:
         list of 10 values fetched from ``self.accuracy``, one per label.
         NOTE(review): presumably 1.0 for the predicted class and 0.0 for
         the others, depending on how ``self.accuracy`` is defined; confirm.

     Raises:
         IOError: if the checkpoint '.index' file is missing, or if
             ``cv2.imread`` cannot read ``image_str``.
     """
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
     self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
     try:
         # Load the model.
         self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)
         model_path = os.path.join(backup_path, 'model_%d.ckpt' % (epoch))
         # Check for the '.index' file that accompanies a V2 checkpoint.
         # Raise explicitly: an `assert` would be stripped under `python -O`.
         if not os.path.exists(model_path + '.index'):
             raise IOError('checkpoint not found: %s.index' % (model_path))
         self.saver.restore(self.sess, model_path)
         print('read model from %s' % (model_path))

         # cv2.imread returns None instead of raising when it cannot read
         # the file, so fail here with a clear message rather than inside
         # cv2.resize.
         # NOTE(review): cv2 loads channels in BGR order; confirm this
         # matches the channel order used during training.
         image = cv2.imread(image_str)
         if image is None:
             raise IOError('could not read image from %s' % (image_str))
         image = cv2.resize(image, (32, 32), interpolation=cv2.INTER_CUBIC)
         # Add a leading batch axis of size 1 and convert to float.
         images = np.array([image], dtype='float')

         # NOTE(review): a new Corpus is built on every call only to reuse
         # its preprocessing; if its constructor loads the dataset this is
         # costly. Verify.
         cifar10 = Corpus()
         test_images = cifar10.data_augmentation(images,
                                                 flip=False,
                                                 crop=True,
                                                 crop_shape=(24, 24, 3),
                                                 whiten=True,
                                                 noise=False)

         # Evaluate the same image once against each candidate label.
         tst_list = []
         for label in np.array(range(10)):
             [avg_accuracy] = self.sess.run(fetches=[self.accuracy],
                                            feed_dict={
                                                self.images: test_images,
                                                self.labels: [label],
                                                self.keep_prob: 1.0
                                            })
             tst_list.append(avg_accuracy)
         return tst_list
     finally:
         # Always release the session (and the GPU memory it reserved),
         # even when restore, image loading, or inference fails.
         self.sess.close()
Esempio n. 2
0
# -*- coding: utf8 -*-
# author: ronniecao
import os
from src.data.cifar10 import Corpus

# Select the GPU at import time. The model modules are imported lazily
# inside the functions below, so they load after this runs. GPUs are
# enumerated in PCI bus order and only device '1' is made visible.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '1',
})

# Shared CIFAR-10 corpus, passed as `dataloader` by the training functions.
cifar10 = Corpus()

def basic_cnn():
    """Train the network defined in basic.yaml on the shared CIFAR-10 corpus.

    Builds a 3-channel, 10-class ConvNet on 24x24 inputs and trains it for
    500 epochs with batch size 128, writing backups to
    'backups/cifar10-v1/'.
    """
    from src.model.basic_cnn import ConvNet

    model = ConvNet(
        n_channel=3,
        n_classes=10,
        image_size=24,
        network_path='src/config/networks/basic.yaml',
    )
    model.train(
        dataloader=cifar10,
        backup_path='backups/cifar10-v1/',
        batch_size=128,
        n_epoch=500,
    )
    # Alternative calls kept for reference (disabled):
    # model.debug()
    # model.test(dataloader=cifar10, backup_path='backup/cifar10-v2/', epoch=5000, batch_size=128)
    # model.observe_salience(batch_size=1, n_channel=3, num_test=10, epoch=2)
    # model.observe_hidden_distribution(batch_size=128, n_channel=3, num_test=1, epoch=980)
    
def vgg_cnn():
    """Train the network defined in vgg.yaml on the shared CIFAR-10 corpus.

    Uses the same ConvNet class as `basic_cnn`, but with the VGG network
    config. It is a 3-channel, 10-class model on 24x24 inputs, trained for
    500 epochs with batch size 128, writing backups to
    'backups/cifar10-v2/'.
    """
    from src.model.basic_cnn import ConvNet

    network_path = 'src/config/networks/vgg.yaml'
    model = ConvNet(n_channel=3, n_classes=10, image_size=24,
                    network_path=network_path)
    model.train(
        dataloader=cifar10,
        backup_path='backups/cifar10-v2/',
        batch_size=128,
        n_epoch=500,
    )
    # Alternative calls kept for reference (disabled):
    # model.debug()
    # model.test(backup_path='backup/cifar10-v3/', epoch=0, batch_size=128)
    # model.observe_salience(batch_size=1, n_channel=3, num_test=10, epoch=2)
    # model.observe_hidden_distribution(batch_size=128, n_channel=3, num_test=1, epoch=980)
    
def resnet():
    from src.model.resnet import ConvNet
    convnet = ConvNet(n_channel=3, n_classes=10, image_size=24, network_path='src/config/networks/resnet.yaml')