def train(): """ 训练 """ resume = TRAIN_HYPER_PARAMS["resume"] num_epochs = TRAIN_HYPER_PARAMS["num_epochs"] keep_prob = TRAIN_HYPER_PARAMS["keep_prob"] class_per_batch = TRAIN_HYPER_PARAMS["class_per_batch"] shoe_per_class = TRAIN_HYPER_PARAMS["shoe_per_class"] img_per_shoe = TRAIN_HYPER_PARAMS["img_per_shoe"] save_step = TRAIN_HYPER_PARAMS["save_step"] test_step = TRAIN_HYPER_PARAMS["test_step"] train_test = TRAIN_HYPER_PARAMS["train_test"] dev_test = TRAIN_HYPER_PARAMS["dev_test"] max_mini_batch_size = class_per_batch * \ shoe_per_class * (shoe_per_class-1) / 2 # GPU Config config = tf.ConfigProto() if GPU.enable: config.gpu_options.per_process_gpu_memory_fraction = GPU.memory_fraction config.gpu_options.allow_growth = True os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join( map(lambda x: str(x), GPU.devices)) else: os.environ["CUDA_VISIBLE_DEVICES"] = "-1" model = Model(TRAIN_HYPER_PARAMS) recorder = Recorder(RECORDER_PATH, resume=resume) recorder.upload_params(TRAIN_HYPER_PARAMS) # train data data_set = data_import(augment=ALL) img_arrays = data_set["img_arrays"] indices = data_set["indices"] train_size = len(indices) # test data if train_test or dev_test: test_img_arrays, test_data_map, _ = test_data_import( augment=[TRANSPOSE], action_type="train") train_scope_length = len(test_data_map["train"][0]["scope_indices"]) train_num_augment = len(test_data_map["train"][0]["indices"]) dev_scope_length = len(test_data_map["dev"][0]["scope_indices"]) dev_num_augment = len(test_data_map["dev"][0]["indices"]) graph = tf.Graph() with graph.as_default(): tf.set_random_seed(SEED) if resume: model.import_meta_graph() model.get_ops_from_graph(graph) else: model.init_ops() # test 计算图 if train_test or dev_test: test_embeddings_length = len(test_img_arrays) if train_test: model.init_test_ops("train", train_scope_length, train_num_augment, test_embeddings_length) if dev_test: model.init_test_ops("dev", dev_scope_length, dev_num_augment, test_embeddings_length) with tf.Session(graph=graph, config=config) as sess: if resume: model.load(sess) print("成功恢复模型 {}".format(model.name)) else: model.init_saver() sess.run(tf.global_variables_initializer()) model.update_learning_rate(sess) clock = time.time() for epoch in range(recorder.checkpoint + 1, num_epochs + 1): recorder.update_checkpoint(epoch) # train train_costs = [] triplet_cache = [] for batch_index, triplets in BatchLoader( model, indices, class_per_batch=class_per_batch, shoe_per_class=shoe_per_class, img_per_shoe=img_per_shoe, img_arrays=img_arrays, sess=sess): # 小数据 cache 机制 if len(triplets) == 0: continue elif len(triplets) + len( triplet_cache) <= max_mini_batch_size // 2: triplet_cache.extend(triplets) continue elif max_mini_batch_size // 2 < len(triplets) + len( triplet_cache) <= max_mini_batch_size: triplets.extend(triplet_cache) triplet_cache.clear() triplet_list = [list(line) for line in zip(*triplets)] mini_batch_size = len(triplet_list[0]) _, temp_cost = sess.run( [model.ops["train_step"], model.ops["loss"]], feed_dict={ model.ops["A"]: np.divide(img_arrays[triplet_list[0]], 127.5, dtype=np.float32) - 1, model.ops["P"]: np.divide(img_arrays[triplet_list[1]], 127.5, dtype=np.float32) - 1, model.ops["N"]: np.divide(img_arrays[triplet_list[2]], 127.5, dtype=np.float32) - 1, model.ops["is_training"]: True, model.ops["keep_prob"]: keep_prob }) temp_cost /= max_mini_batch_size print("{} mini-batch > {}/{} size: {} cost: {} ".format( epoch, batch_index, train_size, mini_batch_size, temp_cost), end="\r") train_costs.append(temp_cost) train_cost = sum(train_costs) / len(train_costs) # test log_str = "{}/{} {} train cost is {}".format( epoch, num_epochs, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), train_cost) if epoch % test_step == 0: train_top_1_acc, train_top_5_acc, dev_top_1_acc, dev_top_5_acc = "", "", "", "" if train_test or dev_test: test_embeddings = model.compute_embeddings( test_img_arrays, sess=sess) if train_test: _, train_top_1_acc, train_top_5_acc = data_test( test_data_map, "train", test_embeddings, sess, model, log=False) log_str += " train top-1:{:.2%} top-5:{:.2%}".format( train_top_1_acc, train_top_5_acc) if dev_test: _, dev_top_1_acc, dev_top_5_acc = data_test( test_data_map, "dev", test_embeddings, sess, model, log=False) log_str += " dev top-1:{:.2%} top-5:{:.2%}".format( dev_top_1_acc, dev_top_5_acc) # 预计完成时间 prec_time_stamp = (time.time() - clock) * \ ((num_epochs - epoch) // test_step) + clock clock = time.time() log_str += " >> {} ".format( time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(prec_time_stamp))) recorder.record_item(epoch, [ train_top_1_acc, train_top_5_acc, dev_top_1_acc, dev_top_5_acc ]) if epoch % save_step == 0: model.save(sess) recorder.save() print(log_str)