import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Project-level imports (Vocab, batcher, PGN, PGN_TRANSFORMER, greedy_decode,
# beam_decode, train_model, config_gpu, get_rouge) come from this repo's own
# modules; their import paths are omitted here.


def test(params):
    assert params["mode"].lower() == "test", "change training mode to 'test' or 'eval'"
    # Beam decoding assumes beam_size == batch_size:
    # assert params["beam_size"] == params["batch_size"], "Beam size must be equal to batch_size, change the params"

    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count

    print("Creating the batcher ...")
    dataset, params['steps_per_epoch'] = batcher(vocab, params)

    print("Building the model ...")
    model = PGN_TRANSFORMER(params)

    print("Creating the checkpoint manager")
    ckpt = tf.train.Checkpoint(PGN_TRANSFORMER=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, params['transformer_model_dir'], max_to_keep=5)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Model restored")

    if params['decode_mode'] == "greedy":
        # greedy_decode consumes the whole dataset in a single pass,
        # so it runs once rather than once per batch
        yield greedy_decode(model, dataset, vocab, params)
    else:
        # beam_decode works one batch at a time
        for batch in dataset:
            yield beam_decode(model, batch, vocab, params, params['print_info'])
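# A minimal usage sketch (assumed, not part of the original code): test() is a
# generator, so decoding only runs as the caller iterates it. The params keys
# mirror the ones test() reads above; batcher() and the model will read further
# keys (e.g. batch_size) not listed here, and all values are placeholders.
def run_test_demo():
    params = {
        "mode": "test",
        "vocab_path": "data/vocab.txt",              # placeholder path
        "vocab_size": 30000,
        "transformer_model_dir": "ckpt/transformer",  # placeholder path
        "decode_mode": "greedy",                      # or "beam"
        "print_info": True,
    }
    for result in test(params):
        print(result)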
# Train entry point for the Transformer-based PGN.
def train(params):
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count

    print("Creating the batcher ...")
    batch, params['steps_per_epoch'] = batcher(vocab, params)

    print("Building the model ...")
    model = PGN_TRANSFORMER(params)

    print("Creating the checkpoint manager")
    checkpoint = tf.train.Checkpoint(PGN_TRANSFORMER=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, params['transformer_model_dir'], max_to_keep=5)

    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
        # Checkpoint names end in "-<number>" (e.g. "ckpt-12"); parse the whole
        # number rather than only the last character, which breaks past epoch 9
        params["trained_epoch"] = int(checkpoint_manager.latest_checkpoint.split('-')[-1])
    else:
        print("Initializing from scratch.")
        params["trained_epoch"] = 1

    print("Starting the training ...")
    train_model(model, batch, params, checkpoint_manager)
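# A minimal training-entry sketch (assumed): the keys shown are the ones the
# Transformer train() above reads directly; train_model() and batcher() will
# read further keys (epochs, learning rate, batch_size, ...) not listed here.
def run_train_demo():
    params = {
        "mode": "train",
        "vocab_path": "data/vocab.txt",              # placeholder path
        "vocab_size": 30000,
        "transformer_model_dir": "ckpt/transformer",  # placeholder path
    }
    train(params)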
# Train entry point for the seq2seq PGN variant (same name as the Transformer
# trainer above; in the repo they presumably live in separate modules).
def train(params):
    # Configure GPU resources
    config_gpu()

    # Build the vocab for training
    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count

    # Build the model
    print("Building the model ...")
    model = PGN(params)

    print("Creating the batcher ...")
    train_dataset, params['train_steps_per_epoch'] = batcher(vocab, params)
    # Reuse the batcher to build the validation set, then restore train mode
    params["mode"] = 'val'
    val_dataset, params['val_steps_per_epoch'] = batcher(vocab, params)
    params["mode"] = 'train'

    # Set up the checkpoint manager
    print("Creating the checkpoint manager")
    checkpoint = tf.train.Checkpoint(PGN=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, params['checkpoint_dir'], max_to_keep=5)

    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
        print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
        # Parse the full step number from names like "ckpt-12"
        params["trained_epoch"] = int(checkpoint_manager.latest_checkpoint.split('-')[-1])
    else:
        print("Initializing from scratch.")
        params["trained_epoch"] = 1

    # Learning-rate decay: shrink the rate by 5% per epoch already trained
    params["learning_rate"] *= np.power(0.95, params["trained_epoch"])
    print('learning_rate:{}'.format(params["learning_rate"]))

    # Train the model
    print("Starting the training ...")
    train_model(model, train_dataset, val_dataset, params, checkpoint_manager)
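# The decay step above compounds multiplicatively per trained epoch. A small
# standalone illustration of the same arithmetic (names here are ours, not
# the repo's):
def decayed_learning_rate(initial_lr, trained_epoch, decay=0.95):
    # e.g. decayed_learning_rate(0.001, 3) == 0.001 * 0.95**3 ≈ 0.000857
    return initial_lr * np.power(decay, trained_epoch)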
def predict_result(model, params, vocab, result_save_path):
    dataset, params['steps_per_epoch'] = batcher(vocab, params)

    if params['decode_mode'] == 'beam':
        # Beam search decodes one batch at a time, keeping the best hypothesis
        results = []
        for batch in tqdm(dataset, total=params['steps_per_epoch']):
            best_hyp = beam_decode(model, batch, vocab, params, print_info=True)
            results.append(best_hyp.abstract)
    else:
        # Greedy decoding runs over the whole dataset in one pass
        results = greedy_decode(model, dataset, vocab, params)

    get_rouge(results)

    # Save the results
    if not os.path.exists(os.path.dirname(result_save_path)):
        os.makedirs(os.path.dirname(result_save_path))
    # save_predict_result(results, result_save_path)
    return results
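# A minimal prediction sketch (assumed): builds the vocab and seq2seq PGN the
# same way the trainer above does, then decodes. Keys and paths are
# placeholders; beam decoding expects beam_size == batch_size (see test()).
def run_predict_demo():
    params = {
        "mode": "test",
        "vocab_path": "data/vocab.txt",    # placeholder path
        "vocab_size": 30000,
        "decode_mode": "beam",
        "beam_size": 4,
        "batch_size": 4,
    }
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    params['vocab_size'] = vocab.count
    model = PGN(params)
    results = predict_result(model, params, vocab, "results/predictions.txt")
    print("{} summaries decoded".format(len(results)))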