def test(self):
    """Evaluate accuracy over a fixed number of test batches and log the result.

    Runs the prediction op with dropout disabled (keep_prob fed as 1.0)
    and reports the fraction of correctly classified samples.
    """
    correct = 0
    # test_epoch = self._data.test_batch_number
    test_epoch = 100  # fixed batch count; the data-driven value above is disabled
    keep_prob = 0  # overwritten by the fetched keep_prob tensor inside the loop
    for batch_index in range(test_epoch):
        images, labels = self._data.next_test(batch_index)
        prediction, keep_prob = self._sess.run(
            [self._prediction, self._keep_prob],
            feed_dict={self._images: images, self._keep_prob: 1.0})
        correct += np.sum(np.equal(labels, prediction))
    total = test_epoch * self._batch_size
    Tools.print_info("the result is {} ({}/{} keep_prob {})".format(
        correct / (total * 1.0), correct, total, keep_prob))
    pass
def __init__(self, zip_file, ratio=4):
    """Prepare train/test image folders from a zip archive.

    zip_file: path to the source archive ("<name>.zip"); data is
        extracted next to it into "<name>/".
    ratio: roughly 1 out of every `ratio` images goes to the test split
        (forwarded to get_data_result).
    """
    data_path = zip_file.split(".zip")[0]
    self.train_path = os.path.join(data_path, "train")
    self.test_path = os.path.join(data_path, "test")
    if not os.path.exists(data_path):
        # Extract and split only on first run. Fix: close the archive
        # handle via a context manager — the original leaked the open
        # ZipFile object.
        with zipfile.ZipFile(zip_file, "r") as archive:
            archive.extractall(data_path)
        all_image = self.get_all_images(
            os.path.join(data_path, data_path.split("/")[-1]))
        self.get_data_result(all_image, ratio,
                             Tools.new_dir(self.train_path),
                             Tools.new_dir(self.test_path))
    else:
        Tools.print_info("data is exists")
    pass
def get_data_result(all_image, ratio, train_path, test_path):
    """Split images into shuffled train/test sets and copy them to disk.

    all_image: per-class lists of image paths (index = class label).
    ratio: roughly 1 out of every `ratio` images goes to the test split.
    train_path / test_path: target directories (assumed to exist).

    Files are written as "<index>-<label>.jpg". Fix: the original used a
    random int in [0, 1000000) as the filename prefix, so two files could
    collide and silently overwrite each other, dropping samples; a running
    index guarantees uniqueness while keeping the same
    "<number>-<label>.jpg" pattern.
    """
    train_list = []
    test_list = []
    # Traverse: assign each image to a split (~1/ratio go to test).
    Tools.print_info("bian")
    for now_type, now_images in enumerate(all_image):
        for now_image in now_images:
            if np.random.randint(0, ratio) == 0:
                test_list.append((now_type, now_image))
            else:
                train_list.append((now_type, now_image))

    # Shuffle so the on-disk order carries no class ordering.
    Tools.print_info("shuffle")
    np.random.shuffle(train_list)
    np.random.shuffle(test_list)

    def copy_split(pairs, target_path):
        # Copy each (label, path) pair into target_path with a
        # collision-free "<index>-<label>.jpg" name.
        for index, (now_type, image) in enumerate(pairs):
            shutil.copyfile(
                image,
                os.path.join(target_path,
                             str(index) + "-" + str(now_type) + ".jpg"))

    # Copy training images and labels.
    Tools.print_info("train")
    copy_split(train_list, train_path)
    # Copy test images and labels.
    Tools.print_info("test")
    copy_split(test_list, test_path)
    pass
def train(self, epochs, save_model, min_loss, print_loss, test, save, keep_prob=0.5):
    """Train the network for `epochs` steps.

    Logs the loss every `print_loss` steps, evaluates on the test set
    every `test` steps, checkpoints to `save_model` every `save` steps,
    then runs a final evaluation. `min_loss` is currently unused (the
    early-stopping check is disabled). `keep_prob` is the dropout
    keep-probability fed during training.
    """
    self._sess.run(tf.global_variables_initializer())
    # To resume from an existing checkpoint instead of training from
    # scratch, restore before the loop:
    #   self._saver.restore(self._sess, save_model)
    epoch = 0
    for epoch in range(epochs):
        images, labels = self._data.next_train()
        feed = {
            self._images: images,
            self._labels: labels,
            self._keep_prob: keep_prob
        }
        fetches = [self._loss, self._softmax, self._solver, self._learning_rate]
        loss, _, _, learning_rate = self._sess.run(fetches=fetches, feed_dict=feed)
        if epoch % print_loss == 0:
            Tools.print_info("{} learning_rate:{} loss {}".format(
                epoch, learning_rate, loss))
        if epoch != 0 and epoch % test == 0:
            self.test()
        if epoch != 0 and epoch % save == 0:
            self._saver.save(self._sess, save_path=save_model)
    Tools.print_info("{}: train end".format(epoch))
    self.test()
    Tools.print_info("test end")
    pass
type=int, default=1000, help="decay_steps") parser.add_argument("-skip_layers", type=str, default=[], help="finetune skip these layers") parser.add_argument("-zip_file", type=str, default="../data/resisc45.zip", help="zip file path") args = parser.parse_args() output_param = "name={},epochs={},batch_size={},type_number={},image_size={},image_channel={},zip_file={},keep_prob={}" Tools.print_info( output_param.format(args.name, args.epochs, args.batch_size, args.type_number, args.image_size, args.image_channel, args.zip_file, args.keep_prob)) # now_train_path, now_test_path = PreData.main(zip_file=args.zip_file) # now_data = Data(batch_size=args.batch_size, type_number=args.type_number, image_size=args.image_size, # image_channel=args.image_channel, train_path=now_train_path, test_path=now_test_path) now_data = Cifar10Data(batch_size=args.batch_size, type_number=args.type_number, image_size=args.image_size, image_channel=args.image_channel) now_net = VGGNet(now_data.type_number, now_data.image_size, now_data.image_channel, now_data.batch_size) runner = Runner(data=now_data,