import tensorflow as tf

# Project-level helpers (config_gpu, Vocab, PGN, train_model, CKPT_DIR, ...)
# come from the repo's own modules.


def train(params):
    # Configure GPU resources
    config_gpu()

    # Build the vocabulary for training
    print("Building vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    # Build the model
    print("Building the model ...")
    model = PGN(params)

    # Create the checkpoint manager. The checkpoint key should name the PGN
    # model (the original used Seq2Seq=model, which would not match the
    # PGN=model key used at restore time in test()).
    print("Creating the checkpoint manager")
    checkpoint = tf.train.Checkpoint(PGN=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, CKPT_DIR, max_to_keep=5)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # Train the model
    print("Starting the training ...")
    train_model(model, vocab, params, checkpoint_manager)
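config_gpu() is called throughout but not shown in this snippet. A minimal sketch of what it likely does under TF 2.x, enabling memory growth so TensorFlow allocates GPU memory on demand instead of reserving the whole card; the function name is the repo's, the body is an assumption:

import tensorflow as tf


def config_gpu():
    # Allocate GPU memory on demand rather than reserving it all up front.
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)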
import numpy as np
import tensorflow as tf


def train(params):
    # Configure GPU resources
    config_gpu()

    # Build the vocabulary for training
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params["vocab_size"] = vocab.count

    # Resume the schedule: decay the learning rate by 0.9 per epoch already
    # trained, so a restored run continues where it left off.
    params["trained_epoch"] = get_train_msg()
    params["learning_rate"] *= np.power(0.9, params["trained_epoch"])

    # Build the model
    print("Building the model ...")
    model = Seq2Seq(params)

    # Create the checkpoint manager
    checkpoint = tf.train.Checkpoint(Seq2Seq=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, SEQ2SEQ_CKPT, max_to_keep=5)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # Train the model
    print("Starting the training ...")
    print("trained_epoch:", params["trained_epoch"])
    print("mode:", params["mode"])
    print("epochs:", params["epochs"])
    print("batch_size:", params["batch_size"])
    print("max_enc_len:", params["max_enc_len"])
    print("max_dec_len:", params["max_dec_len"])
    print("learning_rate:", params["learning_rate"])
    train_model(model, vocab, params, checkpoint_manager)
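get_train_msg() returns the number of epochs already completed, which drives the learning-rate decay above. Its definition is not shown; a minimal sketch assuming the counter is persisted as a small text file next to the checkpoints (the file name and location are hypothetical, and train_model() would have to update it each epoch for the resume logic to work):

import os

# Hypothetical path; SEQ2SEQ_CKPT is the checkpoint directory from the code above.
TRAIN_MSG_PATH = os.path.join(SEQ2SEQ_CKPT, "trained_epoch.txt")


def get_train_msg():
    # Number of epochs already trained; 0 on a fresh run.
    if os.path.exists(TRAIN_MSG_PATH):
        with open(TRAIN_MSG_PATH) as f:
            return int(f.read().strip())
    return 0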
def test(params):
    assert params["mode"].lower() in ["test", "eval"], \
        "change training mode to 'test' or 'eval'"
    assert params["beam_size"] == params["batch_size"], \
        "Beam size must be equal to batch_size, change the params"

    # Configure GPU resources
    config_gpu()

    print("Test the model ...")
    model = PGN(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    print("Creating the checkpoint manager")
    checkpoint = tf.train.Checkpoint(PGN=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, PGN_CKPT, max_to_keep=5)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    if params["greedy_decode"]:
        # Greedy decoding is independent of beam size, so a large batch is fine;
        # beam search below requires batch_size == beam_size.
        params["batch_size"] = 512
        results = predict_result(model, params, vocab, params["result_save_path"])
    else:
        b = beam_test_batch_generator(params["beam_size"])
        results = []
        for batch in b:
            best_hyp = beam_decode(model, batch, vocab, params)
            results.append(best_hyp.abstract)
        save_predict_result(results, params["result_save_path"])
    print("Saved results to: {}".format(params["result_save_path"]))
    return results
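save_predict_result() is called in the beam-search branch but not defined here. A minimal sketch of what it might look like, assuming results is a list of decoded summary strings and that results are written as CSV; pandas and the "Prediction" column name are assumptions, not the repo's confirmed implementation:

import pandas as pd


def save_predict_result(results, result_save_path):
    # Write one decoded summary per row.
    pd.DataFrame({"Prediction": results}).to_csv(result_save_path, index=False)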
{ "dec_input" : entry["dec_input"], "dec_target": entry["target"], "dec_len": entry["dec_len"], "abstract": entry["abstract"], "dec_mask": entry["dec_mask"]}) dataset = dataset.map(update) return dataset def batcher(vocab, params): # if params['mode'] == 'train' and params['load_batch_train_data']: # dataset = load_batch_generator(params) # else: dataset = batch_generator(example_generator, params, vocab) return dataset if __name__ == "__main__": # GPU资源配置 config_gpu() # 获取参数 params = get_params() params['mode'] = 'train' # vocab 对象 vocab = Vocab(VOCAB_PAD) ds = batcher(vocab, params) batch = next(iter(ds.take(1)))