示例#1
0
def test_gata_double_dqn_update_target_action_selector():
    """Syncing should copy the online action selector into the target."""
    gata_ddqn = GATADoubleDQN()

    # Perturb the online selector so the two networks start out different.
    gata_ddqn.action_selector.node_name_word_ids.fill_(42)
    gata_ddqn.action_selector.node_embeddings = nn.Embedding(
        gata_ddqn.num_nodes, gata_ddqn.hparams.node_emb_dim)

    # After the sync, every target parameter must mirror its online twin.
    gata_ddqn.update_target_action_selector()
    online_params = gata_ddqn.action_selector.parameters()
    target_params = gata_ddqn.target_action_selector.parameters()
    for online_p, target_p in zip(online_params, target_params):
        assert online_p.equal(target_p)
示例#2
0
def test_rl_early_stopping():
    """Exercise RLEarlyStopping's three stopping regimes."""
    model = GATADoubleDQN()
    trainer = Trainer()
    es = RLEarlyStopping("val_monitor", "train_monitor", 0.95, patience=3)

    # Regime 1: both scores under the 0.95 threshold -> keep training.
    trainer.callback_metrics = {"val_monitor": 0.1, "train_monitor": 0.1}
    es._run_early_stopping_check(trainer, model)
    assert not trainer.should_stop

    # Regime 2: perfect val score with train at the threshold -> stop now.
    trainer.callback_metrics = {"val_monitor": 1.0, "train_monitor": 0.95}
    trainer.current_epoch = 1
    es._run_early_stopping_check(trainer, model)
    assert trainer.should_stop
    assert es.stopped_epoch == 1

    # Regime 3: train score at the threshold for `patience` consecutive
    # checks stops training even though val never reaches 1.0.
    trainer.should_stop = False
    es.wait_count = 0
    es.stopped_epoch = 0
    for epoch in range(3):
        trainer.current_epoch = epoch
        trainer.callback_metrics = {"val_monitor": 0.9, "train_monitor": 0.95}
        es._run_early_stopping_check(trainer, model)
        if epoch < 2:
            # Patience not yet exhausted.
            assert not trainer.should_stop
            assert es.stopped_epoch == 0
        else:
            # Third consecutive check trips the stop.
            assert trainer.should_stop
            assert es.stopped_epoch == 2
示例#3
0
def test_gata_double_dqn_forward(
    batch_size,
    obs_len,
    prev_action_len,
    num_action_cands,
    action_cand_len,
):
    """Forward pass should produce correctly shaped scores, hidden state
    and current-graph tensors for arbitrary batch/sequence dimensions."""
    gata_ddqn = GATADoubleDQN()

    # Build random-but-valid inputs for every argument of forward().
    obs_word_ids = torch.randint(gata_ddqn.num_words, (batch_size, obs_len))
    obs_mask = increasing_mask(batch_size, obs_len)
    prev_action_word_ids = torch.randint(
        gata_ddqn.num_words, (batch_size, prev_action_len))
    prev_action_mask = increasing_mask(batch_size, prev_action_len)
    rnn_prev_hidden = torch.rand(batch_size, gata_ddqn.hparams.hidden_dim)
    action_cand_word_ids = torch.randint(
        gata_ddqn.num_words, (batch_size, num_action_cands, action_cand_len))
    action_cand_mask = increasing_mask(
        batch_size * num_action_cands, action_cand_len
    ).view(batch_size, num_action_cands, action_cand_len)
    action_mask = increasing_mask(batch_size, num_action_cands)

    results = gata_ddqn(
        obs_word_ids,
        obs_mask,
        prev_action_word_ids,
        prev_action_mask,
        rnn_prev_hidden,
        action_cand_word_ids,
        action_cand_mask,
        action_mask,
    )

    # One score per action candidate.
    assert results["action_scores"].size() == (batch_size, num_action_cands)
    # Updated recurrent hidden state.
    assert results["rnn_curr_hidden"].size() == (
        batch_size,
        gata_ddqn.hparams.hidden_dim,
    )
    # Adjacency tensor: one matrix per relation.
    assert results["current_graph"].size() == (
        batch_size,
        gata_ddqn.num_relations,
        gata_ddqn.num_nodes,
        gata_ddqn.num_nodes,
    )
示例#4
0
def agent_simple_words():
    """Build an Agent over a tiny hand-picked word vocabulary."""
    simple_vocab = [PAD, UNK, "action", "1", "2", "3",
                    "examine", "cookbook", "table"]
    preprocessor = SpacyPreprocessor(simple_vocab)
    graph_updater = GraphUpdaterObsGen().graph_updater
    action_selector = GATADoubleDQN().action_selector
    return Agent(graph_updater, action_selector, preprocessor)
示例#5
0
def agent():
    """Build an Agent with the full word vocabulary from disk."""
    vocab_path = "vocabs/word_vocab.txt"
    obs_gen = GraphUpdaterObsGen(word_vocab_path=vocab_path)
    action_selector = GATADoubleDQN(word_vocab_path=vocab_path).action_selector
    return Agent(obs_gen.graph_updater, action_selector, obs_gen.preprocessor)
示例#6
0
def replay_buffer_gata_double_dqn():
    """GATADoubleDQN configured with a small replay buffer for fast tests."""
    config = dict(
        train_game_batch_size=2,
        train_max_episode_steps=5,
        replay_buffer_populate_episodes=10,
        yield_step_freq=10,
        replay_buffer_capacity=20,
        train_sample_batch_size=4,
    )
    return GATADoubleDQN(**config)
示例#7
0
def eps_greedy_agent():
    """Epsilon-greedy agent whose components all come from one GATADoubleDQN."""
    model = GATADoubleDQN(word_vocab_path="vocabs/word_vocab.txt")
    # NOTE(review): the positional epsilon args here (0.1, 1.0) are the
    # reverse of the sibling fixture's (1.0, 0.1) — confirm which order
    # EpsilonGreedyAgent expects.
    return EpsilonGreedyAgent(
        model.graph_updater,
        model.action_selector,
        model.preprocessor,
        0.1,
        1.0,
        20,
    )
示例#8
0
def eps_greedy_agent():
    """Epsilon-greedy agent mixing an obs-gen graph updater with a fresh
    GATADoubleDQN action selector."""
    vocab_path = "vocabs/word_vocab.txt"
    obs_gen = GraphUpdaterObsGen(word_vocab_path=vocab_path)
    action_selector = GATADoubleDQN(word_vocab_path=vocab_path).action_selector
    # NOTE(review): epsilon args (1.0, 0.1) are the reverse of the sibling
    # fixture's (0.1, 1.0) — confirm which order EpsilonGreedyAgent expects.
    return EpsilonGreedyAgent(
        obs_gen.graph_updater,
        action_selector,
        obs_gen.preprocessor,
        1.0,
        0.1,
        20,
    )
示例#9
0
from train_gata import request_infos_for_eval, GATADoubleDQN
from agent import Agent


if __name__ == "__main__":
    import argparse

    # CLI: a TextWorld game file plus a trained GATADoubleDQN checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument("game_file")
    parser.add_argument("gata_double_dqn_ckpt")
    args = parser.parse_args()

    # Restore the model from the checkpoint; vocab paths are supplied
    # explicitly so the restored model resolves them relative to this repo.
    gata_double_dqn = GATADoubleDQN.load_from_checkpoint(
        args.gata_double_dqn_ckpt,
        word_vocab_path="vocabs/word_vocab.txt",
        node_vocab_path="vocabs/node_vocab.txt",
        relation_vocab_path="vocabs/relation_vocab.txt",
    )
    # Wrap the restored components in an Agent for interactive play.
    agent = Agent(
        gata_double_dqn.graph_updater,
        gata_double_dqn.action_selector,
        gata_double_dqn.preprocessor,
    )

    # Register the game as a gym environment with eval-time info requests.
    # NOTE(review): `textworld` and `gym` are used here but not imported in
    # this excerpt — presumably imported elsewhere in the file; verify.
    env_id = textworld.gym.register_game(
        args.game_file, request_infos=request_infos_for_eval()
    )
    env = gym.make(env_id)

    # Recurrent state for the play loop (continues beyond this excerpt).
    prev_actions = None
    rnn_prev_hidden = None
示例#10
0
def test_gata_double_dqn_get_q_values(action_scores, action_mask, actions_idx,
                                      expected):
    """get_q_values should gather the masked scores at the chosen indices."""
    q_values = GATADoubleDQN.get_q_values(
        action_scores, action_mask, actions_idx)
    assert q_values.equal(expected)
示例#11
0
def test_gata_double_dqn_default_init():
    """A default-constructed GATADoubleDQN should wire up its environments,
    vocabularies, and online/target/frozen networks correctly."""
    gata_ddqn = GATADoubleDQN()

    # Every split uses the two test games; each env carries the right
    # request infos, batch size, and split tag in its registered spec id.
    env_expectations = [
        (gata_ddqn.train_env, "train", request_infos_for_train(),
         gata_ddqn.hparams.train_game_batch_size),
        (gata_ddqn.val_env, "val", request_infos_for_eval(),
         gata_ddqn.hparams.eval_game_batch_size),
        (gata_ddqn.test_env, "test", request_infos_for_eval(),
         gata_ddqn.hparams.eval_game_batch_size),
    ]
    for env, split, infos, batch_size in env_expectations:
        assert len(env.gamefiles) == 2
        assert env.request_infos == infos
        assert env.batch_size == batch_size
        assert env.spec.id.split("-")[1] == split

    # Default word vocabulary is just the four special tokens, and the
    # word embedding table is sized to match it.
    default_word_vocab = [PAD, UNK, BOS, EOS]
    assert gata_ddqn.preprocessor.word_vocab == default_word_vocab
    assert gata_ddqn.graph_updater.word_embeddings[0].weight.size() == (
        len(default_word_vocab),
        gata_ddqn.hparams.word_emb_dim,
    )

    # Default node vocab (['node']): single-word node names.
    num_nodes = len(gata_ddqn.node_vocab)
    assert gata_ddqn.graph_updater.node_name_word_ids.size() == (num_nodes, 1)
    assert gata_ddqn.graph_updater.node_name_mask.size() == (num_nodes, 1)

    # Default relation vocab (['relation', 'relation reverse']):
    # two-word relation names.
    num_rels = len(gata_ddqn.relation_vocab)
    assert gata_ddqn.graph_updater.rel_name_word_ids.size() == (num_rels, 2)
    assert gata_ddqn.graph_updater.rel_name_mask.size() == (num_rels, 2)

    # Online selector is trainable; the target selector stays in train
    # mode but with all of its parameters frozen.
    assert gata_ddqn.action_selector.training
    assert gata_ddqn.target_action_selector.training
    for frozen_param in gata_ddqn.target_action_selector.parameters():
        assert frozen_param.requires_grad is False

    # The target starts out as an exact copy of the online selector.
    online_params = gata_ddqn.action_selector.parameters()
    target_params = gata_ddqn.target_action_selector.parameters()
    for online_p, target_p in zip(online_params, target_params):
        assert online_p.equal(target_p)

    # The graph updater is frozen and kept in eval mode.
    assert not gata_ddqn.graph_updater.training
    for frozen_param in gata_ddqn.graph_updater.parameters():
        assert frozen_param.requires_grad is False