Example No. 1
def test(params):
    assert params["mode"].lower() == "test", "change training mode to 'test' or 'eval'"
    # assert params["beam_size"] == params["batch_size"], "Beam size must be equal to batch_size, change the params"

    print("Building the model ...")
    model = SequenceToSequence(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    print("Creating the batcher ...")
    b = batcher(vocab, params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["seq2seq_model_dir"])
    ckpt = tf.train.Checkpoint(SequenceToSequence=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

    # path = params["model_path"] if params["model_path"] else ckpt_manager.latest_checkpoint
    # path = ckpt_manager.latest_checkpoint
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Model restored")
    # for batch in b:
    #     yield batch_greedy_decode(model, batch, vocab, params)
    if params['greedy_decode']:
        # params['batch_size'] = 512
        predict_result(model, params, vocab, params['test_save_dir'])
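For reference, the test function above reads at least the following keys from params. A minimal sketch of such a dict, with placeholder paths and values rather than the project's real defaults:

params = {
    "mode": "test",
    "vocab_path": "data/vocab.txt",       # placeholder path
    "vocab_size": 30000,
    "seq2seq_model_dir": "ckpt/seq2seq",  # placeholder path
    "batch_size": 32,
    "beam_size": 32,                      # the commented assert expects beam_size == batch_size
    "greedy_decode": True,
    "test_save_dir": "results",           # placeholder path
}
test(params)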
Example No. 2
def train(params):
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    print('true vocab is ', vocab)

    print("Creating the batcher ...")
    b = batcher(vocab, params)

    print("Building the model ...")
    model = SequenceToSequence(params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["seq2seq_model_dir"])
    ckpt = tf.train.Checkpoint(SequenceToSequence=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    print("Starting the training ...")
    train_model(model, b, params, ckpt, ckpt_manager)
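The restore-or-initialize pattern above is standard tf.train.Checkpoint usage. A self-contained sketch of the same pattern, with a toy variable standing in for the full SequenceToSequence model (the /tmp path is a placeholder):

import tensorflow as tf

step = tf.Variable(0)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, "/tmp/demo_ckpt", max_to_keep=5)
ckpt.restore(manager.latest_checkpoint)  # no-op when latest_checkpoint is None
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")
step.assign_add(1)
manager.save()  # writes ckpt-1, ckpt-2, ... keeping at most the last 5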
Example No. 3
def train(params):
    assert params["mode"].lower() == "train", "change training mode to 'train'"
    # Matches the file vocab.txt; the vocab_size parameter is set to 30000.
    # The Vocab class is defined in the batcher module.
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    # print('true vocab is ', vocab)  # commented out: prints only the object repr
    print('true vocab is ', vocab.count)  # equals the configured 30000

    print("Creating the batcher ...")
    b = batcher(vocab, params)
    # print(type(b))
    # <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'>

    print("Building the model ...")
    model = SequenceToSequence(params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint_vocab30000".format(
        params["seq2seq_model_dir"])
    ckpt = tf.train.Checkpoint(SequenceToSequence=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_dir,
                                              max_to_keep=5)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    print("Starting the training ...")
    train_model(model, b, params, ckpt, ckpt_manager)
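The Vocab class itself lives in the project's batcher module and isn't shown here. A minimal sketch of the interface these examples rely on — a constructor taking a path and a size cap, plus the count attribute printed above — assuming a one-token-per-line vocab file (the class name and file layout are assumptions):

class VocabSketch:
    """Hypothetical stand-in for the project's Vocab class."""

    def __init__(self, vocab_path, vocab_size):
        self.word2id = {}
        with open(vocab_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i >= vocab_size:
                    break  # cap the vocabulary, e.g. at 30000
                self.word2id[line.split()[0]] = i
        self.count = len(self.word2id)  # what Example No. 3 prints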
Example No. 4
def test(params):
    assert params["mode"].lower(
    ) == "test", "change training mode to 'test' or 'eval'"
    # assert params["beam_size"] == params["batch_size"], "Beam size must be equal to batch_size, change the params"

    print("Building the model ...")
    model = SequenceToSequence(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    # print("Creating the batcher ...")
    # b = batcher(vocab, params) 在predict_result执行的

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["seq2seq_model_dir"])
    ckpt = tf.train.Checkpoint(SequenceToSequence=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_dir,
                                              max_to_keep=5)

    # path = params["model_path"] if params["model_path"] else ckpt_manager.latest_checkpoint
    # path = ckpt_manager.latest_checkpoint
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Model restored")
    # for batch in b:
    #     yield batch_greedy_decode(model, batch, vocab, params)
    """    
    修改:
        去掉了predict_result 函数
        将处理steps_epoch的共用代码提取出来,再进行分支greedy_decode/beam_decode的出来
    """

    # Calls batcher -> batch_generator in the batcher module; the generator
    # example_generator with mode == "test" starts at line 207.
    dataset = batcher(vocab, params)
    # Number of samples in the test set.
    sample_size = params['sample_size']
    # Ceiling division; the original `// batch_size + 1` takes one extra step
    # when sample_size is an exact multiple of batch_size.
    steps_epoch = (sample_size + params["batch_size"] - 1) // params["batch_size"]
    results = []
    # Build the iterator once; calling `next(iter(dataset))` inside the loop
    # would restart the dataset and re-read the first batch on every step.
    dataset_iter = iter(dataset)
    for i in tqdm(range(steps_epoch)):
        enc_data, _ = next(dataset_iter)

        # If True, run greedy search; otherwise beam search.
        if params['greedy_decode']:
            # print("-----------------greedy_decode 模式-----------------")
            results += batch_greedy_decode(model, enc_data, vocab, params)
        else:
            # print("-----------------beam_decode 模式-----------------")
            # print(enc_data["enc_input"][0])
            # print(enc_data["enc_input"][1])
            # 需要beam sezi=batch size 输入时候相当于遍历一个个X 去进行搜索
            for row in range(params['batch_size']):
                batch = [
                    enc_data["enc_input"][row]
                    for _ in range(params['beam_size'])
                ]
                best_hyp = beam_decode(model, batch, vocab, params)
                results.append(best_hyp.abstract)

    # All batches processed; save the test results.
    results = list(map(lambda x: x.replace(" ", ""), results))
    # Save the results (AutoMaster_TestSet.csv).
    save_predict_result(results, params)

    print('save beam search result to :{}'.format(params['test_x_dir']))
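batch_greedy_decode and beam_decode are defined in the project's decoding module and aren't shown. As a rough illustration of what the greedy branch does, a sketch under heavy assumptions: that the model exposes encoder/decoder callables and that vocab maps words to ids and back (all of these names and signatures are hypothetical):

import tensorflow as tf

def greedy_decode_sketch(model, enc_data, vocab, params):
    """Hypothetical stand-in for batch_greedy_decode: pick the argmax token
    at each step for every sequence in the batch."""
    enc_output, dec_state = model.encoder(enc_data["enc_input"])  # assumed API
    batch_size = enc_data["enc_input"].shape[0]
    dec_input = tf.fill([batch_size, 1], vocab.word_to_id("<START>"))  # assumed API
    decoded = [[] for _ in range(batch_size)]
    for _ in range(params["max_dec_len"]):  # "max_dec_len" key is an assumption
        preds, dec_state = model.decoder(dec_input, dec_state, enc_output)  # assumed API
        next_ids = tf.argmax(preds, axis=-1, output_type=tf.int32)
        next_ids = tf.reshape(next_ids, [batch_size, 1])
        for i, wid in enumerate(next_ids.numpy().reshape(-1)):
            decoded[i].append(vocab.id_to_word(int(wid)))  # assumed API
        dec_input = next_ids  # feed the predicted token back in
    return [" ".join(words) for words in decoded]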