Example 1
def test_graph_updater_obs_gen_forward(batch_size, obs_len, prev_action_len,
                                       rnn_prev_hidden, training):
    g = GraphUpdaterObsGen()
    g.train(training)
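    # Fake episode data: random word ids paired with random binary masks.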
    episode_data = {
        "obs_word_ids":
        torch.randint(g.num_words, (batch_size, obs_len)),
        "obs_mask":
        torch.randint(2, (batch_size, obs_len)).float(),
        "prev_action_word_ids":
        torch.randint(g.num_words, (batch_size, prev_action_len)),
        "prev_action_mask":
        torch.randint(2, (batch_size, prev_action_len)).float(),
        "groundtruth_obs_word_ids":
        torch.randint(g.num_words, (batch_size, obs_len)),
    }
    results = g(
        episode_data,
        rnn_prev_hidden=torch.rand(batch_size, g.hparams.hidden_dim)
        if rnn_prev_hidden else None,
    )
    assert results["h_t"].size() == (batch_size, g.hparams.hidden_dim)
    assert results["batch_loss"].size() == (batch_size, )
    if not training:
        assert results["pred_obs_word_ids"].size() == (batch_size, obs_len)
        # decoded_obs_word_ids has variable lengths
        assert results["decoded_obs_word_ids"].size(0) == batch_size
        assert results["decoded_obs_word_ids"].ndim == 2
Example 2
def test_graph_updater_obs_gen_process_batch(batch_size, obs_len,
                                             prev_action_len, max_episode_len,
                                             hidden, training):
    g = GraphUpdaterObsGen()
    g.train(training)
    batch = [{
        "obs_word_ids":
        torch.randint(g.num_words, (batch_size, obs_len)),
        "obs_mask":
        torch.randint(2, (batch_size, obs_len)).float(),
        "prev_action_word_ids":
        torch.randint(g.num_words, (batch_size, prev_action_len)),
        "prev_action_mask":
        torch.randint(2, (batch_size, prev_action_len)).float(),
        "groundtruth_obs_word_ids":
        torch.randint(g.num_words, (batch_size, obs_len)),
        "step_mask":
        torch.randint(2, (batch_size, )).float(),
    } for _ in range(max_episode_len)]
    h_t = torch.rand(batch_size, g.hparams.hidden_dim) if hidden else None
    results = g.process_batch(batch, h_t=h_t)
    assert len(results["losses"]) == max_episode_len
    assert all(loss.ndim == 0 for loss in results["losses"])
    assert len(results["hiddens"]) == max_episode_len
    assert all(hidden.size() == (batch_size, g.hparams.hidden_dim)
               for hidden in results["hiddens"])
    if not training:
        assert len(results["preds"]) == max_episode_len
        assert all(pred.size() == (batch_size, obs_len)
                   for pred in results["preds"])
        assert len(results["decoded"]) == max_episode_len
        assert all(dec.ndim == 2 for dec in results["decoded"])
        assert all(dec.size(0) == batch_size for dec in results["decoded"])
        assert len(results["f1s"]) <= max_episode_len * batch_size
        assert all(f1.ndim == 0 for f1 in results["f1s"])
Example 3
def test_graph_updater_obs_gen_greedy_decode(batch_size, num_node,
                                             prev_action_len):
    g = GraphUpdaterObsGen()
    decoded = g.greedy_decode(
        torch.rand(batch_size, num_node, g.hparams.hidden_dim),
        torch.rand(batch_size, prev_action_len, g.hparams.hidden_dim),
        torch.randint(2, (batch_size, prev_action_len)).float(),
    )
    assert decoded.ndim == 2
    assert decoded.size(0) == batch_size
    # [BOS] + max_decode_len
    assert decoded.size(1) <= g.hparams.max_decode_len + 1
    # Always start with BOS
    assert decoded[:, 0].equal(
        torch.tensor([g.preprocessor.word_to_id(BOS)] * batch_size))
Example 4
def agent_simple_words():
    preprocessor = SpacyPreprocessor(
        [PAD, UNK, "action", "1", "2", "3", "examine", "cookbook", "table"])
    return Agent(
        GraphUpdaterObsGen().graph_updater,
        GATADoubleDQN().action_selector,
        preprocessor,
    )
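agent_simple_words above reads like a pytest fixture body whose decorator was stripped in the excerpt. A minimal sketch of how such a fixture is usually declared and consumed; the decorator, the stand-in return value, and the consuming test are assumptions, not part of the original code:

import pytest

@pytest.fixture
def agent_simple_words():
    # In the real suite this returns the Agent built in Example 4;
    # a stand-in object keeps the sketch self-contained.
    return object()

def test_agent_construction(agent_simple_words):
    # pytest injects the fixture's return value by matching the argument name.
    assert agent_simple_words is not None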
Example 5
def agent():
    graph_updater_obs_gen = GraphUpdaterObsGen(
        word_vocab_path="vocabs/word_vocab.txt")
    return Agent(
        graph_updater_obs_gen.graph_updater,
        GATADoubleDQN(word_vocab_path="vocabs/word_vocab.txt").action_selector,
        graph_updater_obs_gen.preprocessor,
    )
Example 6
def eps_greedy_agent():
    graph_updater_obs_gen = GraphUpdaterObsGen(
        word_vocab_path="vocabs/word_vocab.txt")
    return EpsilonGreedyAgent(
        graph_updater_obs_gen.graph_updater,
        GATADoubleDQN(word_vocab_path="vocabs/word_vocab.txt").action_selector,
        graph_updater_obs_gen.preprocessor,
        1.0,
        0.1,
        20,
    )
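The three trailing positional arguments to EpsilonGreedyAgent in Example 6 look like an epsilon schedule (a start value of 1.0, an end value of 0.1, and a step count of 20); the constructor signature is not shown here, so that reading is an assumption. As a generic illustration of how such a linear schedule is typically computed:

def linear_epsilon(step, eps_start=1.0, eps_end=0.1, anneal_steps=20):
    # Linearly anneal epsilon from eps_start down to eps_end over anneal_steps steps.
    if step >= anneal_steps:
        return eps_end
    return eps_start + (eps_end - eps_start) * step / anneal_steps

# linear_epsilon(0) == 1.0, linear_epsilon(10) == 0.55, linear_epsilon(20) == 0.1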
Example 7
def test_graph_updater_obs_gen_default_init():
    g = GraphUpdaterObsGen()
    default_word_vocab = [PAD, UNK, BOS, EOS]
    assert g.preprocessor.word_vocab == default_word_vocab
    assert g.graph_updater.word_embeddings[0].weight.size() == (
        len(default_word_vocab),
        g.hparams.word_emb_dim,
    )

    # default node_vocab = ['node']
    assert g.graph_updater.node_name_word_ids.size() == (len(g.node_vocab), 1)
    assert g.graph_updater.node_name_mask.size() == (len(g.node_vocab), 1)

    # default relation_vocab = ['relation', 'relation reverse']
    assert g.graph_updater.rel_name_word_ids.size() == (len(g.relation_vocab),
                                                        2)
    assert g.graph_updater.rel_name_mask.size() == (len(g.relation_vocab), 2)
Example 8
def test_learning_rate_warmup(step, multiplier):
    g = GraphUpdaterObsGen(steps_for_lr_warmup=16)
    assert g.learning_rate_warmup(step) == multiplier
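Example 8 only checks learning_rate_warmup against parametrized (step, multiplier) pairs, so the exact schedule is not visible from the excerpt. One plausible shape, shown purely as an assumption, is a linear ramp that reaches 1.0 after steps_for_lr_warmup steps:

def linear_warmup_multiplier(step, steps_for_lr_warmup=16):
    # Hypothetical linear warmup: the multiplier grows with the step count
    # and is clamped at 1.0 once warmup is complete.
    return min(1.0, (step + 1) / steps_for_lr_warmup)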
Example 9
def main(cfg: DictConfig) -> None:
    print(f"Training with the following config:\n{OmegaConf.to_yaml(cfg)}")

    # seed
    pl.seed_everything(cfg.seed)

    # trainer
    trainer_config = OmegaConf.to_container(cfg.pl_trainer, resolve=True)
    assert isinstance(trainer_config, dict)
    trainer_config["logger"] = instantiate(
        cfg.logger) if "logger" in cfg else True
    val_monitor = "val_avg_game_normalized_rewards"
    train_monitor = "train_avg_game_normalized_rewards"
    trainer_config["callbacks"] = [
        RLEarlyStopping(
            val_monitor,
            train_monitor,
            cfg.train.early_stop_threshold,
            patience=cfg.train.early_stop_patience,
        ),
        EqualModelCheckpoint(
            monitor=val_monitor,
            mode="max",
            filename="{epoch}-{step}-{val_avg_game_normalized_rewards:.2f}",
        ),
    ]
    if isinstance(trainer_config["logger"], WandbLogger):
        trainer_config["callbacks"].append(WandbSaveCallback())
    trainer = pl.Trainer(**trainer_config)

    # instantiate the lightning module
    if not cfg.eval.test_only:
        lm_model_config = OmegaConf.to_container(cfg.model, resolve=True)
        assert isinstance(lm_model_config, dict)
        if cfg.model.pretrained_graph_updater is not None:
            graph_updater_obs_gen = GraphUpdaterObsGen.load_from_checkpoint(
                to_absolute_path(cfg.model.pretrained_graph_updater.ckpt_path),
                word_vocab_path=(
                    cfg.model.pretrained_graph_updater.word_vocab_path),
                node_vocab_path=(
                    cfg.model.pretrained_graph_updater.node_vocab_path),
                relation_vocab_path=(
                    cfg.model.pretrained_graph_updater.relation_vocab_path),
            )
            lm_model_config["pretrained_graph_updater"] = (
                graph_updater_obs_gen.graph_updater)
        lm = GATADoubleDQN(**lm_model_config, **cfg.train, **cfg.data)

        # fit
        trainer.fit(lm)

        # test
        trainer.test()
    else:
        assert (cfg.eval.checkpoint_path
                is not None), "missing checkpoint path for testing"
        parsed = urlparse(cfg.eval.checkpoint_path)
        if parsed.scheme == "":
            # local path
            ckpt_path = to_absolute_path(cfg.eval.checkpoint_path)
        else:
            # remote path
            ckpt_path = cfg.eval.checkpoint_path

        model = GATADoubleDQN.load_from_checkpoint(
            ckpt_path,
            base_data_dir=cfg.data.base_data_dir,
            word_vocab_path=cfg.model.word_vocab_path,
            node_vocab_path=cfg.model.node_vocab_path,
            relation_vocab_path=cfg.model.relation_vocab_path,
        )
        trainer.test(model=model)
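main in Example 9 takes a DictConfig and relies on hydra.utils helpers (instantiate, to_absolute_path), which suggests it is registered as a Hydra entrypoint. A minimal sketch of that boilerplate, assuming a Hydra app; the config_path and config_name values are placeholders, not taken from the original repository:

import hydra
from hydra.utils import instantiate, to_absolute_path
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    ...  # body as in Example 9

if __name__ == "__main__":
    main()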