def test(self, backup_path='backup/cifar10-v16/', epoch=500, image_str=''):
    """Score a single image file against every class label with a checkpoint.

    Restores ``model_<epoch>.ckpt`` from *backup_path*, reads the image at
    *image_str* with OpenCV, resizes it to 32x32 and runs it through
    ``Corpus.data_augmentation``. That call uses crop enabled with
    crop_shape (24, 24, 3), whitening enabled, and flip and noise disabled.
    It then evaluates ``self.accuracy`` once for each candidate label 0..9.

    Args:
        backup_path: Directory containing the TensorFlow (Saver V2) checkpoints.
        epoch: Epoch number selecting the ``model_<epoch>.ckpt`` checkpoint.
        image_str: Path of the image file to evaluate.

    Returns:
        A list of 10 values of ``self.accuracy``, one per candidate label.
        Presumably ``self.accuracy`` is the fraction of correct predictions
        in the batch. If so, each entry is 1.0 for the predicted class and
        0.0 otherwise. TODO confirm against the graph definition.

    Raises:
        IOError: If the checkpoint index file does not exist, or if OpenCV
            cannot read *image_str*.
    """
    # Validate inputs before acquiring a GPU session, so a bad path cannot
    # leak one. Explicit raises (rather than assert) still run under -O.
    model_path = os.path.join(backup_path, 'model_%d.ckpt' % (epoch))
    if not os.path.exists(model_path + '.index'):
        raise IOError('checkpoint not found: %s.index' % model_path)

    image = cv2.imread(image_str)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('could not read image: %s' % image_str)
    # NOTE(review): cv2.imread yields channels in BGR order. Confirm this
    # matches the channel order of the data the model was trained on.
    image = cv2.resize(image, (32, 32), interpolation=cv2.INTER_CUBIC)
    # Batch of one float image.
    images = np.array([image], dtype='float')

    # NOTE(review): if Corpus() loads the whole dataset in its constructor,
    # building one here just to call data_augmentation is expensive. Confirm.
    cifar10 = Corpus()
    test_images = cifar10.data_augmentation(
        images, flip=False, crop=True, crop_shape=(24, 24, 3),
        whiten=True, noise=False)
    test_labels = np.arange(10)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    try:
        # Load the model.
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)
        self.saver.restore(self.sess, model_path)
        print('read model from %s' % (model_path))

        tst_list = []
        for label in test_labels:
            # Feed the same image paired with one candidate label per run.
            # keep_prob=1.0 presumably disables dropout at inference time.
            [avg_accuracy] = self.sess.run(
                fetches=[self.accuracy],
                feed_dict={
                    self.images: test_images,
                    self.labels: [label],
                    self.keep_prob: 1.0,
                })
            tst_list.append(avg_accuracy)
    finally:
        # Always release the session (and its GPU memory), even on error.
        self.sess.close()
    return tst_list
# -*- coding: utf8 -*- # author: ronniecao import os from src.data.cifar10 import Corpus os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ['CUDA_VISIBLE_DEVICES'] = '1' cifar10 = Corpus() def basic_cnn(): from src.model.basic_cnn import ConvNet convnet = ConvNet(n_channel=3, n_classes=10, image_size=24, network_path='src/config/networks/basic.yaml') # convnet.debug() convnet.train(dataloader=cifar10, backup_path='backups/cifar10-v1/', batch_size=128, n_epoch=500) # convnet.test(dataloader=cifar10, backup_path='backup/cifar10-v2/', epoch=5000, batch_size=128) # convnet.observe_salience(batch_size=1, n_channel=3, num_test=10, epoch=2) # convnet.observe_hidden_distribution(batch_size=128, n_channel=3, num_test=1, epoch=980) def vgg_cnn(): from src.model.basic_cnn import ConvNet convnet = ConvNet(n_channel=3, n_classes=10, image_size=24, network_path='src/config/networks/vgg.yaml') # convnet.debug() convnet.train(dataloader=cifar10, backup_path='backups/cifar10-v2/', batch_size=128, n_epoch=500) # convnet.test(backup_path='backup/cifar10-v3/', epoch=0, batch_size=128) # convnet.observe_salience(batch_size=1, n_channel=3, num_test=10, epoch=2) # convnet.observe_hidden_distribution(batch_size=128, n_channel=3, num_test=1, epoch=980) def resnet(): from src.model.resnet import ConvNet convnet = ConvNet(n_channel=3, n_classes=10, image_size=24, network_path='src/config/networks/resnet.yaml')