import os
import time

import numpy as np
import tensorflow as tf

# NOTE: Model, Recorder, BatchLoader, Cacher, data_import, test_data_import,
# data_test and the configuration constants (TRAIN_HYPER_PARAMS, TEST_PARAMS,
# GPU, GLOBAL, SEED, ALL, TRANSPOSE, BILATERAL_BLUR, RECORDER_PATH,
# SAMPLE_EMB_CACHE, RESULT_FILE, SEP, IS_LOG, IS_PLOT) are assumed to be
# provided by this project's own modules.


def train():
    """ Training """
    resume = TRAIN_HYPER_PARAMS["resume"]
    num_epochs = TRAIN_HYPER_PARAMS["num_epochs"]
    keep_prob = TRAIN_HYPER_PARAMS["keep_prob"]
    class_per_batch = TRAIN_HYPER_PARAMS["class_per_batch"]
    shoe_per_class = TRAIN_HYPER_PARAMS["shoe_per_class"]
    img_per_shoe = TRAIN_HYPER_PARAMS["img_per_shoe"]
    save_step = TRAIN_HYPER_PARAMS["save_step"]
    test_step = TRAIN_HYPER_PARAMS["test_step"]
    train_test = TRAIN_HYPER_PARAMS["train_test"]
    dev_test = TRAIN_HYPER_PARAMS["dev_test"]

    # Integer division, so the size can be compared against triplet counts
    # below without a float/int mismatch.
    max_mini_batch_size = class_per_batch * \
        shoe_per_class * (shoe_per_class - 1) // 2

    # GPU config
    config = tf.ConfigProto()
    if GPU.enable:
        config.gpu_options.per_process_gpu_memory_fraction = GPU.memory_fraction
        config.gpu_options.allow_growth = True
        # CUDA expects a comma-separated device list without spaces
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU.devices))
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    model = Model(TRAIN_HYPER_PARAMS)
    recorder = Recorder(RECORDER_PATH, resume=resume)
    recorder.upload_params(TRAIN_HYPER_PARAMS)

    # train data
    data_set = data_import(augment=ALL)
    img_arrays = data_set["img_arrays"]
    indices = data_set["indices"]
    train_size = len(indices)

    # test data
    if train_test or dev_test:
        test_img_arrays, test_data_map, _ = test_data_import(
            augment=[TRANSPOSE], action_type="train")
        train_scope_length = len(test_data_map["train"][0]["scope_indices"])
        train_num_augment = len(test_data_map["train"][0]["indices"])
        dev_scope_length = len(test_data_map["dev"][0]["scope_indices"])
        dev_num_augment = len(test_data_map["dev"][0]["indices"])

    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(SEED)
        if resume:
            model.import_meta_graph()
            model.get_ops_from_graph(graph)
        else:
            model.init_ops()

        # test graph
        if train_test or dev_test:
            test_embeddings_length = len(test_img_arrays)
            if train_test:
                model.init_test_ops("train", train_scope_length,
                                    train_num_augment, test_embeddings_length)
            if dev_test:
                model.init_test_ops("dev", dev_scope_length,
                                    dev_num_augment, test_embeddings_length)

    with tf.Session(graph=graph, config=config) as sess:
        if resume:
            model.load(sess)
            print("Successfully restored model {}".format(model.name))
        else:
            model.init_saver()
            sess.run(tf.global_variables_initializer())
        model.update_learning_rate(sess)

        clock = time.time()
        for epoch in range(recorder.checkpoint + 1, num_epochs + 1):
            recorder.update_checkpoint(epoch)

            # train
            train_costs = []
            triplet_cache = []
            for batch_index, triplets in BatchLoader(
                    model, indices,
                    class_per_batch=class_per_batch,
                    shoe_per_class=shoe_per_class,
                    img_per_shoe=img_per_shoe,
                    img_arrays=img_arrays,
                    sess=sess):
                # Small-batch cache: hold back undersized triplet sets and
                # merge them into a later batch so each optimizer step sees a
                # reasonably sized mini-batch.
                if len(triplets) == 0:
                    continue
                elif len(triplets) + len(triplet_cache) <= max_mini_batch_size // 2:
                    triplet_cache.extend(triplets)
                    continue
                elif max_mini_batch_size // 2 < len(triplets) + len(triplet_cache) <= max_mini_batch_size:
                    triplets.extend(triplet_cache)
                    triplet_cache.clear()

                # transpose [(a, p, n), ...] into [anchors, positives, negatives]
                triplet_list = [list(line) for line in zip(*triplets)]
                mini_batch_size = len(triplet_list[0])
                _, temp_cost = sess.run(
                    [model.ops["train_step"], model.ops["loss"]],
                    feed_dict={
                        # scale uint8 pixels from [0, 255] into [-1, 1]
                        model.ops["A"]: np.divide(img_arrays[triplet_list[0]],
                                                  127.5, dtype=np.float32) - 1,
                        model.ops["P"]: np.divide(img_arrays[triplet_list[1]],
                                                  127.5, dtype=np.float32) - 1,
                        model.ops["N"]: np.divide(img_arrays[triplet_list[2]],
                                                  127.5, dtype=np.float32) - 1,
                        model.ops["is_training"]: True,
                        model.ops["keep_prob"]: keep_prob
                    })
                # normalize by the fixed maximum size so costs stay comparable
                # across mini-batches of different sizes
                temp_cost /= max_mini_batch_size
                print("{} mini-batch > {}/{} size: {} cost: {} ".format(
                    epoch, batch_index, train_size, mini_batch_size,
                    temp_cost), end="\r")
                train_costs.append(temp_cost)
            train_cost = sum(train_costs) / len(train_costs)

            # test
            log_str = "{}/{} {} train cost is {}".format(
                epoch, num_epochs,
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                train_cost)
            if epoch % test_step == 0:
                train_top_1_acc, train_top_5_acc = "", ""
                dev_top_1_acc, dev_top_5_acc = "", ""
                if train_test or dev_test:
                    test_embeddings = model.compute_embeddings(
                        test_img_arrays, sess=sess)
                if train_test:
                    _, train_top_1_acc, train_top_5_acc = data_test(
                        test_data_map, "train", test_embeddings,
                        sess, model, log=False)
                    log_str += " train top-1:{:.2%} top-5:{:.2%}".format(
                        train_top_1_acc, train_top_5_acc)
                if dev_test:
                    _, dev_top_1_acc, dev_top_5_acc = data_test(
                        test_data_map, "dev", test_embeddings,
                        sess, model, log=False)
                    log_str += " dev top-1:{:.2%} top-5:{:.2%}".format(
                        dev_top_1_acc, dev_top_5_acc)

                # estimated completion time
                prec_time_stamp = (time.time() - clock) * \
                    ((num_epochs - epoch) // test_step) + clock
                clock = time.time()
                log_str += " >> {} ".format(
                    time.strftime("%Y-%m-%d %H:%M:%S",
                                  time.localtime(prec_time_stamp)))
                recorder.record_item(epoch, [
                    train_top_1_acc, train_top_5_acc,
                    dev_top_1_acc, dev_top_5_acc
                ])

            if epoch % save_step == 0:
                model.save(sess)
                recorder.save()
            print(log_str)
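
# A minimal usage sketch (not part of the original pipeline): the feed_dict in
# train() maps raw uint8 pixels into [-1, 1] via img / 127.5 - 1. The helper
# name `normalize_images` is hypothetical and only illustrates that arithmetic.
def normalize_images(img_batch):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return np.divide(img_batch, 127.5, dtype=np.float32) - 1

# Example: normalize_images(np.array([0, 127, 255], dtype=np.uint8))
# -> array([-1., -0.00392157, 1.], dtype=float32)
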
def test():
    """ Main test function """
    use_cache = TEST_PARAMS["use_cache"]

    # load the model and data
    model = Model(TRAIN_HYPER_PARAMS)
    sample_cacher = Cacher(SAMPLE_EMB_CACHE)
    img_arrays, test_data_map, sample_length = test_data_import(
        augment=[(TRANSPOSE, BILATERAL_BLUR)], action_type="test")
    GLOBAL["img_arrays"] = img_arrays
    scope_length = len(test_data_map["test"][0]["scope_indices"])
    num_augment = len(test_data_map["test"][0]["indices"])

    # GPU config
    config = tf.ConfigProto()
    if GPU.enable:
        config.gpu_options.per_process_gpu_memory_fraction = GPU.memory_fraction
        config.gpu_options.allow_growth = True
        # CUDA expects a comma-separated device list without spaces
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU.devices))
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # build the TF graph and start the session
    graph = tf.Graph()
    with graph.as_default():
        model.import_meta_graph()
    with tf.Session(graph=graph, config=config) as sess:
        model.load(sess)
        clock = time.time()
        model.get_ops_from_graph(graph)

        # compute embeddings
        if use_cache and os.path.exists(SAMPLE_EMB_CACHE):
            # if the sample embeddings were already encoded, read them directly
            sample_embs = sample_cacher.read()
            shoeprint_embs = model.compute_embeddings(
                img_arrays[sample_length:], sess=sess)
            embeddings = np.concatenate((sample_embs, shoeprint_embs))
            print("Successfully loaded pre-encoded sample embeddings")
        else:
            embeddings = model.compute_embeddings(img_arrays, sess=sess)
            sample_embs = embeddings[:sample_length]
            sample_cacher.save(sample_embs)

        # initialize the test graph
        embeddings_length = len(img_arrays)
        model.init_test_ops("test", scope_length, num_augment,
                            embeddings_length)

        # run the test
        res, top_1_accuracy, top_5_accuracy = data_test(
            test_data_map, "test", embeddings, sess, model,
            log=IS_LOG, plot=IS_PLOT)
        print("Top-1: {:.2%}, Top-5: {:.2%}".format(
            top_1_accuracy, top_5_accuracy))
        print("{:.2f}s".format(time.time() - clock))

        # write the results to the output file
        with open(RESULT_FILE, "w") as f:
            for i, item in enumerate(res):
                f.write(test_data_map["test"][i]["name"])
                for dist in item:
                    f.write(SEP)
                    f.write(str(1 / dist))
                f.write("\n")
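
# A minimal reader sketch (an assumption, not in the original source): as
# written above, RESULT_FILE holds one line per test shoeprint -- its name
# followed by SEP-separated scores, where each score is 1 / distance (larger
# means a closer match). The helper name `read_results` is hypothetical.
def read_results(path, sep):
    """Parse a result file back into {name: [similarity scores]}."""
    results = {}
    with open(path) as f:
        for line in f:
            fields = line.rstrip("\n").split(sep)
            # first field is the shoeprint name, the rest are scores
            results[fields[0]] = [float(score) for score in fields[1:]]
    return results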