Esempio n. 1
0
def train():
    """ 训练 """
    resume = TRAIN_HYPER_PARAMS["resume"]
    num_epochs = TRAIN_HYPER_PARAMS["num_epochs"]
    keep_prob = TRAIN_HYPER_PARAMS["keep_prob"]
    class_per_batch = TRAIN_HYPER_PARAMS["class_per_batch"]
    shoe_per_class = TRAIN_HYPER_PARAMS["shoe_per_class"]
    img_per_shoe = TRAIN_HYPER_PARAMS["img_per_shoe"]
    save_step = TRAIN_HYPER_PARAMS["save_step"]
    test_step = TRAIN_HYPER_PARAMS["test_step"]
    train_test = TRAIN_HYPER_PARAMS["train_test"]
    dev_test = TRAIN_HYPER_PARAMS["dev_test"]
    max_mini_batch_size = class_per_batch * \
        shoe_per_class * (shoe_per_class-1) / 2

    # GPU Config
    config = tf.ConfigProto()
    if GPU.enable:
        config.gpu_options.per_process_gpu_memory_fraction = GPU.memory_fraction
        config.gpu_options.allow_growth = True
        os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join(
            map(lambda x: str(x), GPU.devices))
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    model = Model(TRAIN_HYPER_PARAMS)
    recorder = Recorder(RECORDER_PATH, resume=resume)
    recorder.upload_params(TRAIN_HYPER_PARAMS)

    # train data
    data_set = data_import(augment=ALL)
    img_arrays = data_set["img_arrays"]
    indices = data_set["indices"]
    train_size = len(indices)

    # test data
    if train_test or dev_test:
        test_img_arrays, test_data_map, _ = test_data_import(
            augment=[TRANSPOSE], action_type="train")
        train_scope_length = len(test_data_map["train"][0]["scope_indices"])
        train_num_augment = len(test_data_map["train"][0]["indices"])
        dev_scope_length = len(test_data_map["dev"][0]["scope_indices"])
        dev_num_augment = len(test_data_map["dev"][0]["indices"])

    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(SEED)
        if resume:
            model.import_meta_graph()
            model.get_ops_from_graph(graph)
        else:
            model.init_ops()

        # test 计算图
        if train_test or dev_test:
            test_embeddings_length = len(test_img_arrays)
        if train_test:
            model.init_test_ops("train", train_scope_length, train_num_augment,
                                test_embeddings_length)
        if dev_test:
            model.init_test_ops("dev", dev_scope_length, dev_num_augment,
                                test_embeddings_length)

        with tf.Session(graph=graph, config=config) as sess:
            if resume:
                model.load(sess)
                print("成功恢复模型 {}".format(model.name))
            else:
                model.init_saver()
                sess.run(tf.global_variables_initializer())

            model.update_learning_rate(sess)

            clock = time.time()
            for epoch in range(recorder.checkpoint + 1, num_epochs + 1):
                recorder.update_checkpoint(epoch)
                # train
                train_costs = []
                triplet_cache = []
                for batch_index, triplets in BatchLoader(
                        model,
                        indices,
                        class_per_batch=class_per_batch,
                        shoe_per_class=shoe_per_class,
                        img_per_shoe=img_per_shoe,
                        img_arrays=img_arrays,
                        sess=sess):

                    # 小数据 cache 机制
                    if len(triplets) == 0:
                        continue
                    elif len(triplets) + len(
                            triplet_cache) <= max_mini_batch_size // 2:
                        triplet_cache.extend(triplets)
                        continue
                    elif max_mini_batch_size // 2 < len(triplets) + len(
                            triplet_cache) <= max_mini_batch_size:
                        triplets.extend(triplet_cache)
                        triplet_cache.clear()

                    triplet_list = [list(line) for line in zip(*triplets)]
                    mini_batch_size = len(triplet_list[0])

                    _, temp_cost = sess.run(
                        [model.ops["train_step"], model.ops["loss"]],
                        feed_dict={
                            model.ops["A"]:
                            np.divide(img_arrays[triplet_list[0]],
                                      127.5,
                                      dtype=np.float32) - 1,
                            model.ops["P"]:
                            np.divide(img_arrays[triplet_list[1]],
                                      127.5,
                                      dtype=np.float32) - 1,
                            model.ops["N"]:
                            np.divide(img_arrays[triplet_list[2]],
                                      127.5,
                                      dtype=np.float32) - 1,
                            model.ops["is_training"]:
                            True,
                            model.ops["keep_prob"]:
                            keep_prob
                        })
                    temp_cost /= max_mini_batch_size
                    print("{} mini-batch > {}/{} size: {} cost: {} ".format(
                        epoch, batch_index, train_size, mini_batch_size,
                        temp_cost),
                          end="\r")
                    train_costs.append(temp_cost)
                train_cost = sum(train_costs) / len(train_costs)

                # test
                log_str = "{}/{} {} train cost is {}".format(
                    epoch, num_epochs,
                    time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                    train_cost)

                if epoch % test_step == 0:
                    train_top_1_acc, train_top_5_acc, dev_top_1_acc, dev_top_5_acc = "", "", "", ""
                    if train_test or dev_test:
                        test_embeddings = model.compute_embeddings(
                            test_img_arrays, sess=sess)

                    if train_test:
                        _, train_top_1_acc, train_top_5_acc = data_test(
                            test_data_map,
                            "train",
                            test_embeddings,
                            sess,
                            model,
                            log=False)
                        log_str += " train top-1:{:.2%} top-5:{:.2%}".format(
                            train_top_1_acc, train_top_5_acc)
                    if dev_test:
                        _, dev_top_1_acc, dev_top_5_acc = data_test(
                            test_data_map,
                            "dev",
                            test_embeddings,
                            sess,
                            model,
                            log=False)
                        log_str += " dev top-1:{:.2%} top-5:{:.2%}".format(
                            dev_top_1_acc, dev_top_5_acc)

                    # 预计完成时间
                    prec_time_stamp = (time.time() - clock) * \
                        ((num_epochs - epoch) // test_step) + clock
                    clock = time.time()
                    log_str += " >> {} ".format(
                        time.strftime("%Y-%m-%d %H:%M:%S",
                                      time.localtime(prec_time_stamp)))
                    recorder.record_item(epoch, [
                        train_top_1_acc, train_top_5_acc, dev_top_1_acc,
                        dev_top_5_acc
                    ])

                if epoch % save_step == 0:
                    model.save(sess)
                    recorder.save()

                print(log_str)
Esempio n. 2
0
def test():
    """ 测试主函数 """
    use_cache = TEST_PARAMS["use_cache"]

    # 加载模型与数据
    model = Model(TRAIN_HYPER_PARAMS)
    sample_cacher = Cacher(SAMPLE_EMB_CACHE)
    img_arrays, test_data_map, sample_length = test_data_import(
        augment=[(TRANSPOSE, BILATERAL_BLUR)], action_type="test")
    GLOBAL["img_arrays"] = img_arrays

    scope_length = len(test_data_map["test"][0]["scope_indices"])
    num_augment = len(test_data_map["test"][0]["indices"])

    # GPU Config
    config = tf.ConfigProto()
    if GPU.enable:
        config.gpu_options.per_process_gpu_memory_fraction = GPU.memory_fraction
        config.gpu_options.allow_growth = True
        os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join(
            map(lambda x: str(x), GPU.devices))
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # 启动 TF 计算图与会话
    graph = tf.Graph()
    with graph.as_default():
        model.import_meta_graph()
        with tf.Session(graph=graph, config=config) as sess:
            model.load(sess)

            clock = time.time()
            model.get_ops_from_graph(graph)

            # 计算嵌入
            if use_cache and os.path.exists(SAMPLE_EMB_CACHE):
                # 如果已经有编码过的样本嵌入则直接读取
                sample_embs = sample_cacher.read()
                shoeprint_embs = model.compute_embeddings(
                    img_arrays[sample_length:], sess=sess)
                embeddings = np.concatenate((sample_embs, shoeprint_embs))
                print("成功读取预编码模板")
            else:
                embeddings = model.compute_embeddings(img_arrays, sess=sess)
                sample_embs = embeddings[:sample_length]
                sample_cacher.save(sample_embs)

            # 初始化测试计算图
            embeddings_length = len(img_arrays)
            model.init_test_ops("test", scope_length, num_augment,
                                embeddings_length)

            # 测试数据
            res, top_1_accuracy, top_5_accuracy = data_test(test_data_map,
                                                            "test",
                                                            embeddings,
                                                            sess,
                                                            model,
                                                            log=IS_LOG,
                                                            plot=IS_PLOT)
            print("Top-1: {:.2%}, Top-5: {:.2%}".format(
                top_1_accuracy, top_5_accuracy))
            print("{:.2f}s".format(time.time() - clock))

    # 将结果写入输出文件
    with open(RESULT_FILE, "w") as f:
        for i, item in enumerate(res):
            f.write(test_data_map["test"][i]["name"])
            for dist in item:
                f.write(SEP)
                f.write(str(1 / dist))
            f.write("\n")