def test(params):
    """Restore the latest PGN checkpoint and yield beam-search decodings.

    Args:
        params: dict of run settings. Must contain "mode" ("test" or "eval"),
            "vocab_path", "vocab_size", and "pgn_model_dir", plus whatever
            PGN / batcher / beam_decode read from it.

    Yields:
        The result of ``beam_decode`` for each batch produced by the batcher.

    Raises:
        ValueError: if ``params["mode"]`` is not "test" or "eval".
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently skip this input validation.
    if params["mode"].lower() not in ("test", "eval"):
        raise ValueError("change training mode to 'test' or 'eval'")

    print("Building the model ...")
    model = PGN(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    print("Creating the batcher ...")
    batches = batcher(vocab, params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["pgn_model_dir"])
    ckpt = tf.train.Checkpoint(PGN=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir,
                                              max_to_keep=5)

    # NOTE(review): restoring latest_checkpoint unconditionally — if none
    # exists this restores nothing; confirm a checkpoint is expected here.
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Model restored")

    for batch in batches:
        yield beam_decode(model, batch, vocab, params)
def train(params):
    """Build the PGN model, restore the latest checkpoint if any, and train.

    Args:
        params: dict of run settings. Must contain "mode" ("train"),
            "vocab_path", "vocab_size", and "model_dir", plus whatever
            PGN / batcher / train_model read from it.

    Raises:
        ValueError: if ``params["mode"]`` is not "train".
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently skip this input validation.
    if params["mode"].lower() != "train":
        raise ValueError("change training mode to 'train'")

    # The original built the vocab twice (before the batcher and again after
    # the model); one construction is sufficient and avoids redundant work.
    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    print('true vocab is ', vocab)

    print("Creating the batcher ...")
    b = batcher(vocab, params)

    print("Building the model ...")
    model = PGN(params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}/checkpoint".format(params["model_dir"])
    # Track the global step alongside the model weights in the checkpoint.
    ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir,
                                              max_to_keep=5)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    print("Starting the training ...")
    train_model(model, b, params, ckpt, ckpt_manager)
def test(params):
    """Restore the latest checkpoint and yield greedy decodings per batch.

    Args:
        params: dict of run settings. Must contain "mode" ("test"),
            "beam_size", "batch_size", "vocab_path", "vocab_size", and
            "model_dir", plus whatever PGN / batcher / batch_greedy_decode
            read from it.

    Yields:
        The result of ``batch_greedy_decode`` for each batch from the batcher.

    Raises:
        ValueError: if ``params["mode"]`` is not "test", or if
            ``params["beam_size"] != params["batch_size"]``.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently skip this input validation.  The original
    # message claimed 'eval' was accepted, but the check only allows "test".
    if params["mode"].lower() != "test":
        raise ValueError("change training mode to 'test'")
    if params["beam_size"] != params["batch_size"]:
        raise ValueError(
            "Beam size must be equal to batch_size, change the params")

    print("Building the model ...")
    model = PGN(params)

    print("Creating the vocab ...")
    vocab = Vocab(params["vocab_path"], params["vocab_size"])

    print("Creating the batcher ...")
    b = batcher(vocab, params)

    print("Creating the checkpoint manager")
    checkpoint_dir = "{}".format(params["model_dir"])
    print('checkpoint_dir is ', checkpoint_dir)
    ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir,
                                              max_to_keep=5)

    # NOTE(review): restoring latest_checkpoint unconditionally — if none
    # exists this restores nothing; confirm a checkpoint is expected here.
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Model restored")

    for batch in b:
        yield batch_greedy_decode(model, batch, vocab, params)