Code example #1
import tensorflow as tf

# config_gpu, Vocab, PGN, train_model and CKPT_DIR are helpers/constants
# defined elsewhere in the project.


def train(params):
    # Configure GPU resources
    config_gpu()
    # Build the vocab used for training
    print("Building vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    # Build the model
    print("Building the model ...")
    # model = Seq2Seq(params)
    model = PGN(params)

    # print("Creating the batcher ...")
    # dataset = batcher(params["train_seg_x_dir"], params["train_seg_y_dir"], vocab, params)
    # print('dataset is ', dataset)

    # Create the checkpoint manager
    print("Creating the checkpoint manager")
    # The keyword argument names the tracked model inside the checkpoint; it
    # must match the name used when restoring (code example #3 uses PGN=model).
    checkpoint = tf.train.Checkpoint(PGN=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, CKPT_DIR, max_to_keep=5)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")
    # Train the model
    print("Starting the training ...")
    train_model(model, vocab, params, checkpoint_manager)
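
The restore-or-initialize block above is TensorFlow's standard idiom for
resuming from the newest checkpoint when one exists; restoring from None is
effectively a no-op. A self-contained demonstration with a dummy variable (the
checkpoint directory name here is an arbitrary assumption):

import tensorflow as tf

step = tf.Variable(0)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(checkpoint, "./demo_ckpt", max_to_keep=5)

checkpoint.restore(manager.latest_checkpoint)  # no-op on a fresh directory
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

step.assign_add(1)
manager.save()  # writes ./demo_ckpt/ckpt-1, ckpt-2, ... over successive runs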
Code example #2
import numpy as np
import tensorflow as tf

# config_gpu, Vocab, Seq2Seq, get_train_msg, train_model and SEQ2SEQ_CKPT are
# helpers/constants defined elsewhere in the project.


def train(params):
    # Configure GPU resources
    config_gpu()
    # Build the vocab and sync the actual vocab size back into the params
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count
    # Resume from the recorded epoch count and decay the learning rate by 0.9
    # for every epoch already trained
    params["trained_epoch"] = get_train_msg()
    params["learning_rate"] *= np.power(0.9, params["trained_epoch"])

    # Build the model
    print("Building the model ...")
    model = Seq2Seq(params)
    # Create the checkpoint manager
    checkpoint = tf.train.Checkpoint(Seq2Seq=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    SEQ2SEQ_CKPT,
                                                    max_to_keep=5)

    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # Train the model
    print("Starting to train the model ..")
    print("trained_epoch:", params["trained_epoch"])
    print("mode:", params["mode"])
    print("epochs:", params["epochs"])
    print("batch_size:", params["batch_size"])
    print("max_enc_len:", params["max_enc_len"])
    print("max_dec_len:", params["max_dec_len"])
    print("learning_rate:", params["learning_rate"])

    train_model(model, vocab, params, checkpoint_manager)
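
The learning-rate lines near the top of this example resume training with an
exponentially decayed rate: after k completed epochs the effective rate is the
base rate times 0.9**k. The listing does not show get_train_msg, so the
file-based epoch counter below is purely an assumption; this is a minimal
sketch of that bookkeeping, not the project's implementation:

import os
import numpy as np

def get_train_msg(path="train_msg.txt"):
    # Hypothetical helper: read how many epochs have already been trained,
    # defaulting to 0 for a fresh run.
    if os.path.exists(path):
        with open(path) as f:
            return int(f.read().strip())
    return 0

trained_epoch = get_train_msg()
learning_rate = 0.001 * np.power(0.9, trained_epoch)
print(trained_epoch, learning_rate)  # with 3 recorded epochs: 3 0.000729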
Code example #3
import tensorflow as tf

# config_gpu, Vocab, PGN, predict_result, beam_test_batch_generator,
# beam_decode, save_predict_result and PGN_CKPT are helpers/constants
# defined elsewhere in the project.


def test(params):
    assert params["mode"].lower() in ["test", "eval"], \
        "change mode to 'test' or 'eval'"
    assert params["beam_size"] == params["batch_size"], \
        "Beam size must equal batch_size; change the params"
    # Configure GPU resources
    config_gpu()

    print("Test the model ...")

    model = PGN(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    # ds = batcher(vocab, params)

    print("Creating the checkpoint manager")
    checkpoint = tf.train.Checkpoint(PGN=model)

    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    PGN_CKPT,
                                                    max_to_keep=5)

    # checkpoint_manager = tf.train.CheckpointManager(checkpoint, TEMP_CKPT, max_to_keep=5)
    # temp_ckpt = os.path.join(TEMP_CKPT, "ckpt-5")
    # checkpoint.restore(temp_ckpt)

    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    if params['greedy_decode']:
        # Greedy decoding is not tied to the beam size, so use a large batch
        params['batch_size'] = 512
        results = predict_result(model, params, vocab,
                                 params['result_save_path'])
    else:
        b = beam_test_batch_generator(params["beam_size"])
        results = []
        for batch in b:
            best_hyp = beam_decode(model, batch, vocab, params)
            results.append(best_hyp.abstract)
        save_predict_result(results, params['result_save_path'])
        print('Saved result to: {}'.format(params['result_save_path']))

    return results
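
The first assertions tie beam_size to batch_size because during beam search
every row of the batch carries one hypothesis for the same source article, so
the encoder input is replicated beam_size times. A minimal sketch of that
replication (the shape and token ids are illustrative assumptions, not values
from the project):

import tensorflow as tf

beam_size = 4
enc_input = tf.constant([[3, 17, 42, 8]])        # one article, shape (1, enc_len)
beam_batch = tf.tile(enc_input, [beam_size, 1])  # shape (beam_size, enc_len)
print(beam_batch.shape)                          # (4, 4)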
Code example #4
File: batcher.py Project: dulingkang/NLP-QA-Reason
# (The listing begins mid-file: below is the tail of an `update` function that
#  maps each dataset entry to its decoder-side fields.)
                {
                 "dec_input": entry["dec_input"],
                 "dec_target": entry["target"],
                 "dec_len": entry["dec_len"],
                 "abstract": entry["abstract"],
                 "dec_mask": entry["dec_mask"]})
    dataset = dataset.map(update)
    return dataset


def batcher(vocab, params):
    # if params['mode'] == 'train' and params['load_batch_train_data']:
    #     dataset = load_batch_generator(params)
    # else:
    dataset = batch_generator(example_generator, params, vocab)
    return dataset


if __name__ == "__main__":
    # Configure GPU resources
    config_gpu()
    # Get hyperparameters
    params = get_params()
    params['mode'] = 'train'
    # Vocab object
    vocab = Vocab(VOCAB_PAD)
    ds = batcher(vocab, params)

    batch = next(iter(ds.take(1)))
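
The update step at the top of this file appears to map each record into dicts
of named tensors (the leading, encoder-side part is truncated in the listing),
the (features, labels) structure a tf.data pipeline conventionally hands to a
training loop. A self-contained toy version of the same pattern (the field
values, and the single encoder-side field kept here, are made-up assumptions):

import tensorflow as tf

raw = tf.data.Dataset.from_tensor_slices({
    "enc_input": [[1, 2, 3], [4, 5, 6]],
    "dec_input": [[7, 8], [9, 10]],
    "target": [[8, 2], [10, 2]],
    "dec_len": [2, 2],
})

def update(entry):
    # Split each record into encoder-side features and decoder-side targets.
    return ({"enc_input": entry["enc_input"]},
            {"dec_input": entry["dec_input"],
             "dec_target": entry["target"],
             "dec_len": entry["dec_len"]})

ds = raw.map(update)
features, labels = next(iter(ds.batch(2).take(1)))
print(labels["dec_target"].numpy())  # [[ 8  2], [10  2]]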