def test_gata_double_dqn_update_target_action_selector():
    """Check that updating the target action selector copies the online weights.

    The online action selector is modified first, then
    ``update_target_action_selector()`` is called and every online/target
    parameter pair is compared element-wise.
    """
    model = GATADoubleDQN()
    online_selector = model.action_selector

    # Scramble the online network: overwrite the node-name word ids in place
    # and swap in a freshly constructed node embedding layer.
    online_selector.node_name_word_ids.fill_(42)
    online_selector.node_embeddings = nn.Embedding(
        model.num_nodes, model.hparams.node_emb_dim
    )

    model.update_target_action_selector()

    # After the update, the parameters of both networks must be equal.
    param_pairs = zip(
        model.action_selector.parameters(),
        model.target_action_selector.parameters(),
    )
    for online_param, target_param in param_pairs:
        assert online_param.equal(target_param)
def test_rl_early_stopping():
    """Exercise three scenarios of ``RLEarlyStopping._run_early_stopping_check``.

    1. Validation and training scores both below the 0.95 threshold: no stop.
    2. Validation score of 1.0 with the training score at the threshold: stop.
    3. Training score at the threshold for ``patience`` (3) consecutive checks
       while the validation score stays below 1.0: stop on the last check.
    """
    module = GATADoubleDQN()
    trainer = Trainer()
    early_stopping = RLEarlyStopping(
        "val_monitor", "train_monitor", 0.95, patience=3
    )

    # Scenario 1: both scores are below the threshold, so training continues.
    trainer.callback_metrics = {"val_monitor": 0.1, "train_monitor": 0.1}
    early_stopping._run_early_stopping_check(trainer, module)
    assert not trainer.should_stop

    # Scenario 2: perfect validation score plus a training score of 0.95
    # stops training, and the stopping epoch is recorded.
    trainer.callback_metrics = {"val_monitor": 1.0, "train_monitor": 0.95}
    trainer.current_epoch = 1
    early_stopping._run_early_stopping_check(trainer, module)
    assert trainer.should_stop
    assert early_stopping.stopped_epoch == 1

    # Scenario 3: reset the state, then keep the training score at 0.95 with
    # an imperfect validation score; only the third check should stop.
    trainer.should_stop = False
    early_stopping.wait_count = 0
    early_stopping.stopped_epoch = 0
    for epoch in range(3):
        trainer.current_epoch = epoch
        trainer.callback_metrics = {"val_monitor": 0.9, "train_monitor": 0.95}
        early_stopping._run_early_stopping_check(trainer, module)

        patience_exhausted = epoch == 2
        expected_stopped_epoch = 2 if patience_exhausted else 0
        assert bool(trainer.should_stop) is patience_exhausted
        assert early_stopping.stopped_epoch == expected_stopped_epoch
def test_gata_double_dqn_forward(
    batch_size,
    obs_len,
    prev_action_len,
    num_action_cands,
    action_cand_len,
):
    """Run a forward pass on random inputs and check the output sizes.

    The five size arguments determine the dimensions of the random inputs;
    where they come from (e.g. a parametrization) is not visible here.
    """
    model = GATADoubleDQN()
    vocab_size = model.num_words
    hidden_dim = model.hparams.hidden_dim

    # Inputs are built in the same order as they are passed to the model so
    # that random number consumption is unchanged.
    obs_ids = torch.randint(vocab_size, (batch_size, obs_len))
    obs_mask = increasing_mask(batch_size, obs_len)
    prev_action_ids = torch.randint(vocab_size, (batch_size, prev_action_len))
    prev_action_mask = increasing_mask(batch_size, prev_action_len)
    hidden_state = torch.rand(batch_size, hidden_dim)
    cand_ids = torch.randint(
        vocab_size, (batch_size, num_action_cands, action_cand_len)
    )
    # Build a flat (batch * candidates, length) mask, then fold it back so
    # that it lines up with the candidate ids.
    cand_token_mask = increasing_mask(
        batch_size * num_action_cands, action_cand_len
    ).view(batch_size, num_action_cands, action_cand_len)
    cand_mask = increasing_mask(batch_size, num_action_cands)

    results = model(
        obs_ids,
        obs_mask,
        prev_action_ids,
        prev_action_mask,
        hidden_state,
        cand_ids,
        cand_token_mask,
        cand_mask,
    )

    assert results["action_scores"].size() == (batch_size, num_action_cands)
    assert results["rnn_curr_hidden"].size() == (batch_size, hidden_dim)
    assert results["current_graph"].size() == (
        batch_size,
        model.num_relations,
        model.num_nodes,
        model.num_nodes,
    )
def agent_simple_words():
    """Return an ``Agent`` whose preprocessor uses a small fixed vocabulary.

    The graph updater and action selector come from default-constructed
    ``GraphUpdaterObsGen`` and ``GATADoubleDQN`` instances.
    """
    simple_vocab = [
        PAD,
        UNK,
        "action",
        "1",
        "2",
        "3",
        "examine",
        "cookbook",
        "table",
    ]
    simple_preprocessor = SpacyPreprocessor(simple_vocab)
    graph_updater = GraphUpdaterObsGen().graph_updater
    action_selector = GATADoubleDQN().action_selector
    return Agent(graph_updater, action_selector, simple_preprocessor)
def agent():
    """Return an ``Agent`` built with the word vocabulary file.

    The graph updater and preprocessor are taken from a ``GraphUpdaterObsGen``
    and the action selector from a ``GATADoubleDQN``, both constructed with
    the same word vocabulary path.
    """
    word_vocab_path = "vocabs/word_vocab.txt"
    obs_gen = GraphUpdaterObsGen(word_vocab_path=word_vocab_path)
    action_selector = GATADoubleDQN(word_vocab_path=word_vocab_path).action_selector
    return Agent(obs_gen.graph_updater, action_selector, obs_gen.preprocessor)
def replay_buffer_gata_double_dqn():
    """Return a ``GATADoubleDQN`` configured with small replay-buffer settings."""
    small_config = {
        "train_game_batch_size": 2,
        "train_max_episode_steps": 5,
        "replay_buffer_populate_episodes": 10,
        "yield_step_freq": 10,
        "replay_buffer_capacity": 20,
        "train_sample_batch_size": 4,
    }
    return GATADoubleDQN(**small_config)
def eps_greedy_agent():
    """Return an ``EpsilonGreedyAgent`` built from one ``GATADoubleDQN``.

    The graph updater, action selector and preprocessor all come from the
    same model instance, which is constructed with the word vocabulary file.
    """
    model = GATADoubleDQN(word_vocab_path="vocabs/word_vocab.txt")
    components = (
        model.graph_updater,
        model.action_selector,
        model.preprocessor,
    )
    # NOTE(review): the three trailing positional arguments are presumably the
    # epsilon schedule; another definition with the same name in this source
    # passes the first two in the opposite order (1.0, 0.1) - confirm which
    # order is intended.
    return EpsilonGreedyAgent(*components, 0.1, 1.0, 20)
def eps_greedy_agent():
    """Return an ``EpsilonGreedyAgent`` built from separate components.

    The graph updater and preprocessor come from a ``GraphUpdaterObsGen`` and
    the action selector from a ``GATADoubleDQN``, both constructed with the
    word vocabulary file.
    """
    word_vocab_path = "vocabs/word_vocab.txt"
    obs_gen = GraphUpdaterObsGen(word_vocab_path=word_vocab_path)
    action_selector = GATADoubleDQN(word_vocab_path=word_vocab_path).action_selector
    # NOTE(review): the three trailing positional arguments are presumably the
    # epsilon schedule; another definition with the same name in this source
    # passes the first two in the opposite order (0.1, 1.0) - confirm which
    # order is intended.
    return EpsilonGreedyAgent(
        obs_gen.graph_updater,
        action_selector,
        obs_gen.preprocessor,
        1.0,
        0.1,
        20,
    )
from train_gata import request_infos_for_eval, GATADoubleDQN
from agent import Agent

# Script entry point: load a trained GATADoubleDQN checkpoint, wrap its
# components in an Agent, and create a TextWorld gym environment for one game.
# NOTE(review): `textworld` and `gym` are used below but are not imported in
# the visible code - presumably imported elsewhere; confirm.
if __name__ == "__main__":
    import argparse

    # Two positional arguments: the game file and the checkpoint path.
    parser = argparse.ArgumentParser()
    parser.add_argument("game_file")
    parser.add_argument("gata_double_dqn_ckpt")
    args = parser.parse_args()

    # Restore the model from the checkpoint, passing the vocabulary paths.
    gata_double_dqn = GATADoubleDQN.load_from_checkpoint(
        args.gata_double_dqn_ckpt,
        word_vocab_path="vocabs/word_vocab.txt",
        node_vocab_path="vocabs/node_vocab.txt",
        relation_vocab_path="vocabs/relation_vocab.txt",
    )

    agent = Agent(
        gata_double_dqn.graph_updater,
        gata_double_dqn.action_selector,
        gata_double_dqn.preprocessor,
    )

    # Register the game with the evaluation request infos and instantiate it.
    env_id = textworld.gym.register_game(
        args.game_file, request_infos=request_infos_for_eval()
    )
    env = gym.make(env_id)

    # No previous actions or hidden state exist yet.
    prev_actions = rnn_prev_hidden = None
def test_gata_double_dqn_get_q_values(
    action_scores, action_mask, actions_idx, expected
):
    """Check that ``GATADoubleDQN.get_q_values`` returns the expected tensor."""
    q_values = GATADoubleDQN.get_q_values(action_scores, action_mask, actions_idx)
    assert q_values.equal(expected)
def test_gata_double_dqn_default_init():
    """Verify the state of a ``GATADoubleDQN`` constructed with no arguments.

    Covers the train/val/test environments, the default vocabularies and the
    resulting tensor sizes, and the train/eval mode and frozen status of the
    action selectors and the graph updater.
    """
    gata_ddqn = GATADoubleDQN()
    hparams = gata_ddqn.hparams

    # Each environment is initialized with the two test games; only the
    # request infos, batch size hyperparameter and id suffix differ.
    env_expectations = (
        ("train", request_infos_for_train, "train_game_batch_size"),
        ("val", request_infos_for_eval, "eval_game_batch_size"),
        ("test", request_infos_for_eval, "eval_game_batch_size"),
    )
    for split, make_request_infos, batch_size_hparam in env_expectations:
        env = getattr(gata_ddqn, f"{split}_env")
        assert len(env.gamefiles) == 2
        assert env.request_infos == make_request_infos()
        assert env.batch_size == getattr(hparams, batch_size_hparam)
        assert env.spec.id.split("-")[1] == split

    graph_updater = gata_ddqn.graph_updater

    # default words
    default_word_vocab = [PAD, UNK, BOS, EOS]
    assert gata_ddqn.preprocessor.word_vocab == default_word_vocab
    assert graph_updater.word_embeddings[0].weight.size() == (
        len(default_word_vocab),
        hparams.word_emb_dim,
    )

    # default node_vocab = ['node']
    node_name_size = (len(gata_ddqn.node_vocab), 1)
    assert graph_updater.node_name_word_ids.size() == node_name_size
    assert graph_updater.node_name_mask.size() == node_name_size

    # default relation_vocab = ['relation', 'relation reverse']
    rel_name_size = (len(gata_ddqn.relation_vocab), 2)
    assert graph_updater.rel_name_word_ids.size() == rel_name_size
    assert graph_updater.rel_name_mask.size() == rel_name_size

    # online action selector is in train mode
    assert gata_ddqn.action_selector.training

    # target action selector is in train mode and frozen
    assert gata_ddqn.target_action_selector.training
    assert all(
        param.requires_grad is False
        for param in gata_ddqn.target_action_selector.parameters()
    )

    # online and target action selectors should be initialized to be the same
    param_pairs = zip(
        gata_ddqn.action_selector.parameters(),
        gata_ddqn.target_action_selector.parameters(),
    )
    for online_param, target_param in param_pairs:
        assert online_param.equal(target_param)

    # graph updater is in eval mode and frozen
    assert not graph_updater.training
    assert all(
        param.requires_grad is False for param in graph_updater.parameters()
    )