コード例 #1
0
def test(params):
    """Restore the latest PGN_TRANSFORMER checkpoint and yield decoded results.

    This is a generator. When ``params['decode_mode'] == "greedy"`` it yields
    the result of a single ``greedy_decode`` call over the whole dataset; in
    any other mode it yields one ``beam_decode`` result per batch.

    Args:
        params: Configuration dict. Keys read here: ``mode``, ``vocab_path``,
            ``vocab_size``, ``transformer_model_dir``, ``decode_mode`` and
            ``print_info`` (plus whatever ``batcher`` and the model read).
            ``vocab_size`` and ``steps_per_epoch`` are overwritten in place.

    Raises:
        AssertionError: If ``params["mode"]`` is not ``"test"``
            (case-insensitive).
    """
    # NOTE(review): the message mentions 'eval', but only 'test' passes the check.
    assert params["mode"].lower() == "test", "change training mode to 'test' or 'eval'"

    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count

    print("Creating the batcher ...")
    dataset, params['steps_per_epoch'] = batcher(vocab, params)

    print("Building the model ...")
    model = PGN_TRANSFORMER(params)

    print("Creating the checkpoint manager")
    ckpt = tf.train.Checkpoint(PGN_TRANSFORMER=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, params['transformer_model_dir'], max_to_keep=5)

    latest = ckpt_manager.latest_checkpoint
    ckpt.restore(latest)
    if latest:
        print("Model restored from {}".format(latest))
    else:
        # latest_checkpoint is None when the directory has no checkpoint, in
        # which case restore() loads no weights.
        print("No checkpoint found in {}; decoding with untrained weights".format(
            params['transformer_model_dir']))

    if params['decode_mode'] == "greedy":
        # greedy_decode is handed the full dataset (not a batch), so call it
        # once instead of re-running the whole decode for every batch.
        yield greedy_decode(model, dataset, vocab, params)
    else:
        for batch in dataset:
            yield beam_decode(model, batch, vocab, params, params['print_info'])
コード例 #2
0
def train(params):
    """Build a PGN_TRANSFORMER, resume from the latest checkpoint if any, and train it.

    Args:
        params: Configuration dict. Keys read here: ``mode``, ``vocab_path``,
            ``vocab_size``, ``transformer_model_dir`` (plus whatever
            ``batcher``, the model and ``train_model`` read). Written in
            place: ``vocab_size``, ``steps_per_epoch`` and ``trained_epoch``.

    Raises:
        AssertionError: If ``params["mode"]`` is not ``"train"``
            (case-insensitive).
        ValueError: If the latest checkpoint path does not end in ``-<int>``.
    """
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count

    print("Creating the batcher ...")
    dataset, params['steps_per_epoch'] = batcher(vocab, params)

    print("Building the model ...")
    model = PGN_TRANSFORMER(params)

    print("Creating the checkpoint manager")

    checkpoint = tf.train.Checkpoint(PGN_TRANSFORMER=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, params['transformer_model_dir'], max_to_keep=5)

    latest = checkpoint_manager.latest_checkpoint
    checkpoint.restore(latest)
    if latest:
        print("Restored from {}".format(latest))
        # Checkpoint paths end in "-<number>" (e.g. ".../ckpt-12"). Parse the
        # whole numeric suffix, not just the last character, so 10+ works.
        params["trained_epoch"] = int(latest.rsplit("-", 1)[-1])
    else:
        print("Initializing from scratch.")
        params["trained_epoch"] = 1

    print("Starting the training ...")
    train_model(model, dataset, params, checkpoint_manager)
コード例 #3
0
def train(params):
    """Build a PGN model, resume from the latest checkpoint if any, and train it.

    Args:
        params: Configuration dict. Keys read here: ``vocab_path``,
            ``vocab_size``, ``checkpoint_dir``, ``learning_rate`` (plus
            whatever ``batcher``, the model and ``train_model`` read).
            Written in place: ``vocab_size``, ``train_steps_per_epoch``,
            ``val_steps_per_epoch``, ``trained_epoch`` and ``learning_rate``
            (decayed). ``mode`` is set to ``'val'`` while the validation set
            is built and is left as ``'train'`` afterwards.

    Raises:
        ValueError: If the latest checkpoint path does not end in ``-<int>``.
    """
    # Configure GPU resources.
    config_gpu()

    # Load the vocabulary.
    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count

    # Build the model.
    print("Building the model ...")
    model = PGN(params)

    print("Creating the batcher ...")
    train_dataset, params['train_steps_per_epoch'] = batcher(vocab, params)
    # batcher presumably picks the data split from params["mode"]; switch to
    # 'val' just for this call and always switch back, even if it raises.
    params["mode"] = 'val'
    try:
        val_dataset, params['val_steps_per_epoch'] = batcher(vocab, params)
    finally:
        params["mode"] = 'train'

    # Set up the checkpoint manager.
    print("Creating the checkpoint manager")
    checkpoint = tf.train.Checkpoint(PGN=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    params['checkpoint_dir'],
                                                    max_to_keep=5)
    latest = checkpoint_manager.latest_checkpoint
    checkpoint.restore(latest)
    if latest:
        print("Restored from {}".format(latest))
        # Checkpoint paths end in "-<number>" (e.g. ".../ckpt-12"). Parse the
        # whole numeric suffix, not just the last character, so 10+ works.
        params["trained_epoch"] = int(latest.rsplit("-", 1)[-1])
    else:
        print("Initializing from scratch.")
        params["trained_epoch"] = 1

    # Learning-rate decay: scale by 0.95 ** trained_epoch.
    # NOTE(review): trained_epoch is 1 from scratch, so the rate is decayed once
    # even before any training; confirm that is intended.
    params["learning_rate"] *= np.power(0.95, params["trained_epoch"])
    print('learning_rate:{}'.format(params["learning_rate"]))
    # Train the model.
    print("Starting the training ...")

    train_model(model, train_dataset, val_dataset, params, checkpoint_manager)
コード例 #4
0
def predict_result(model, params, vocab, result_save_path):
    """Decode the dataset with ``model``, score it with ROUGE, and return the predictions.

    Args:
        model: The trained model passed to the decoders.
        params: Configuration dict. ``decode_mode`` selects beam vs. greedy
            decoding; ``steps_per_epoch`` is overwritten in place with the
            value returned by ``batcher``.
        vocab: Vocabulary passed to ``batcher`` and the decoders.
        result_save_path: Intended output file path. Its parent directory is
            created if missing (writing the file is currently disabled).

    Returns:
        In ``'beam'`` mode, a list of each batch's ``best_hyp.abstract``;
        otherwise whatever ``greedy_decode`` returns.
    """
    dataset, params['steps_per_epoch'] = batcher(vocab, params)

    if params['decode_mode'] == 'beam':
        results = [
            beam_decode(model, batch, vocab, params, print_info=True).abstract
            for batch in tqdm(dataset, total=params['steps_per_epoch'])
        ]
    else:
        # Greedy decoding over the full dataset in one call.
        results = greedy_decode(model, dataset, vocab, params)
    get_rouge(results)

    # Make sure the output directory exists. A bare filename has an empty
    # dirname, and os.makedirs("") raises, so only create a non-empty one.
    # exist_ok avoids the race of a separate exists() check.
    save_dir = os.path.dirname(result_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): saving is disabled; re-enable once save_predict_result is confirmed.
    # save_predict_result(results, result_save_path)

    return results