def test_graph_updater_obs_gen_forward(
    batch_size, obs_len, prev_action_len, rnn_prev_hidden, training
):
    g = GraphUpdaterObsGen()
    g.train(training)
    episode_data = {
        "obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
        "obs_mask": torch.randint(2, (batch_size, obs_len)).float(),
        "prev_action_word_ids": torch.randint(
            g.num_words, (batch_size, prev_action_len)
        ),
        "prev_action_mask": torch.randint(2, (batch_size, prev_action_len)).float(),
        "groundtruth_obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
    }
    results = g(
        episode_data,
        rnn_prev_hidden=torch.rand(batch_size, g.hparams.hidden_dim)
        if rnn_prev_hidden
        else None,
    )
    assert results["h_t"].size() == (batch_size, g.hparams.hidden_dim)
    assert results["batch_loss"].size() == (batch_size,)
    if not training:
        assert results["pred_obs_word_ids"].size() == (batch_size, obs_len)
        # decoded_obs_word_ids has variable lengths
        assert results["decoded_obs_word_ids"].size(0) == batch_size
        assert results["decoded_obs_word_ids"].ndim == 2
def test_graph_updater_obs_gen_process_batch(
    batch_size, obs_len, prev_action_len, max_episode_len, hidden, training
):
    g = GraphUpdaterObsGen()
    g.train(training)
    batch = [
        {
            "obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
            "obs_mask": torch.randint(2, (batch_size, obs_len)).float(),
            "prev_action_word_ids": torch.randint(
                g.num_words, (batch_size, prev_action_len)
            ),
            "prev_action_mask": torch.randint(
                2, (batch_size, prev_action_len)
            ).float(),
            "groundtruth_obs_word_ids": torch.randint(
                g.num_words, (batch_size, obs_len)
            ),
            "step_mask": torch.randint(2, (batch_size,)).float(),
        }
        for _ in range(max_episode_len)
    ]
    h_t = torch.rand(batch_size, g.hparams.hidden_dim) if hidden else None
    results = g.process_batch(batch, h_t=h_t)
    assert len(results["losses"]) == max_episode_len
    assert all(loss.ndim == 0 for loss in results["losses"])
    assert len(results["hiddens"]) == max_episode_len
    assert all(
        hidden.size() == (batch_size, g.hparams.hidden_dim)
        for hidden in results["hiddens"]
    )
    if not training:
        assert len(results["preds"]) == max_episode_len
        assert all(pred.size() == (batch_size, obs_len) for pred in results["preds"])
        assert len(results["decoded"]) == max_episode_len
        assert all(dec.ndim == 2 for dec in results["decoded"])
        assert all(dec.size(0) == batch_size for dec in results["decoded"])
        assert len(results["f1s"]) <= max_episode_len * batch_size
        assert all(f1.ndim == 0 for f1 in results["f1s"])
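# Both tests above take plain arguments (batch_size, obs_len, ...), so in the
# repo they are presumably driven by @pytest.mark.parametrize. A hypothetical
# sketch of that wiring; the concrete values below are illustrative, not the
# project's actual parametrization:
import pytest


@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("hidden", [True, False])
@pytest.mark.parametrize(
    "batch_size,obs_len,prev_action_len,max_episode_len",
    [(1, 8, 4, 2), (3, 12, 6, 5)],
)
def test_process_batch_parametrized_example(
    batch_size, obs_len, prev_action_len, max_episode_len, hidden, training
):
    test_graph_updater_obs_gen_process_batch(
        batch_size, obs_len, prev_action_len, max_episode_len, hidden, training
    )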
def test_graph_updater_obs_gen_greedy_decode(batch_size, num_node, prev_action_len):
    g = GraphUpdaterObsGen()
    decoded = g.greedy_decode(
        torch.rand(batch_size, num_node, g.hparams.hidden_dim),
        torch.rand(batch_size, prev_action_len, g.hparams.hidden_dim),
        torch.randint(2, (batch_size, prev_action_len)).float(),
    )
    assert decoded.ndim == 2
    assert decoded.size(0) == batch_size
    # [BOS] + max_decode_len
    assert decoded.size(1) <= g.hparams.max_decode_len + 1
    # always starts with BOS
    assert decoded[:, 0].equal(
        torch.tensor([g.preprocessor.word_to_id(BOS)] * batch_size)
    )
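# A minimal sketch of the greedy-decoding contract the test above checks, under
# the assumption of a generic next-token step function (step_fn is hypothetical;
# GraphUpdaterObsGen.greedy_decode itself may be implemented differently):
# decoding starts from BOS, appends the argmax token each step, and stops at EOS
# or after max_decode_len steps, which is why the decoded tensor is at most
# max_decode_len + 1 tokens wide and always begins with BOS.
import torch


def greedy_decode_sketch(step_fn, bos_id, eos_id, batch_size, max_decode_len):
    # start every sequence with BOS
    decoded = torch.full((batch_size, 1), bos_id, dtype=torch.long)
    for _ in range(max_decode_len):
        # step_fn maps the tokens decoded so far to next-token logits (batch, vocab)
        next_ids = step_fn(decoded).argmax(dim=-1, keepdim=True)
        decoded = torch.cat([decoded, next_ids], dim=1)
        if (next_ids.squeeze(1) == eos_id).all():
            break
    return decoded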
def agent_simple_words():
    preprocessor = SpacyPreprocessor(
        [PAD, UNK, "action", "1", "2", "3", "examine", "cookbook", "table"]
    )
    return Agent(
        GraphUpdaterObsGen().graph_updater,
        GATADoubleDQN().action_selector,
        preprocessor,
    )
def agent():
    graph_updater_obs_gen = GraphUpdaterObsGen(
        word_vocab_path="vocabs/word_vocab.txt"
    )
    return Agent(
        graph_updater_obs_gen.graph_updater,
        GATADoubleDQN(word_vocab_path="vocabs/word_vocab.txt").action_selector,
        graph_updater_obs_gen.preprocessor,
    )
def eps_greedy_agent():
    graph_updater_obs_gen = GraphUpdaterObsGen(
        word_vocab_path="vocabs/word_vocab.txt"
    )
    return EpsilonGreedyAgent(
        graph_updater_obs_gen.graph_updater,
        GATADoubleDQN(word_vocab_path="vocabs/word_vocab.txt").action_selector,
        graph_updater_obs_gen.preprocessor,
        1.0,
        0.1,
        20,
    )
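# agent(), agent_simple_words(), and eps_greedy_agent() read like pytest fixture
# bodies; in the repo they would normally carry @pytest.fixture so tests can
# request them by argument name. A hypothetical sketch of that wiring (the
# fixture and test names here are made up for illustration):
import pytest


@pytest.fixture(name="eps_greedy_agent_fixture")
def _eps_greedy_agent_fixture():
    return eps_greedy_agent()


def test_eps_greedy_agent_builds(eps_greedy_agent_fixture):
    # only checks construction; the fixture wires a graph updater, an action
    # selector, a preprocessor, and the epsilon schedule (1.0 to 0.1 over 20 steps)
    assert isinstance(eps_greedy_agent_fixture, EpsilonGreedyAgent)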
def test_graph_updater_obs_gen_default_init():
    g = GraphUpdaterObsGen()
    default_word_vocab = [PAD, UNK, BOS, EOS]
    assert g.preprocessor.word_vocab == default_word_vocab
    assert g.graph_updater.word_embeddings[0].weight.size() == (
        len(default_word_vocab),
        g.hparams.word_emb_dim,
    )
    # default node_vocab = ['node']
    assert g.graph_updater.node_name_word_ids.size() == (len(g.node_vocab), 1)
    assert g.graph_updater.node_name_mask.size() == (len(g.node_vocab), 1)
    # default relation_vocab = ['relation', 'relation reverse']
    assert g.graph_updater.rel_name_word_ids.size() == (len(g.relation_vocab), 2)
    assert g.graph_updater.rel_name_mask.size() == (len(g.relation_vocab), 2)
def test_learning_rate_warmup(step, multiplier):
    g = GraphUpdaterObsGen(steps_for_lr_warmup=16)
    assert g.learning_rate_warmup(step) == multiplier
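# A minimal sketch of the warmup schedule the test above exercises, assuming a
# plain linear ramp (the actual GraphUpdaterObsGen.learning_rate_warmup may
# differ in details such as off-by-one handling): the multiplier grows linearly
# from 0 to 1 over steps_for_lr_warmup optimizer steps and then stays at 1, so
# with steps_for_lr_warmup=16, step 8 maps to 0.5 and any step >= 16 maps to 1.0.
def linear_warmup_multiplier(step: int, steps_for_lr_warmup: int = 16) -> float:
    return min(step / steps_for_lr_warmup, 1.0)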
def main(cfg: DictConfig) -> None:
    print(f"Training with the following config:\n{OmegaConf.to_yaml(cfg)}")
    # seed
    pl.seed_everything(cfg.seed)

    # trainer
    trainer_config = OmegaConf.to_container(cfg.pl_trainer, resolve=True)
    assert isinstance(trainer_config, dict)
    trainer_config["logger"] = instantiate(cfg.logger) if "logger" in cfg else True
    val_monitor = "val_avg_game_normalized_rewards"
    train_monitor = "train_avg_game_normalized_rewards"
    trainer_config["callbacks"] = [
        RLEarlyStopping(
            val_monitor,
            train_monitor,
            cfg.train.early_stop_threshold,
            patience=cfg.train.early_stop_patience,
        ),
        EqualModelCheckpoint(
            monitor=val_monitor,
            mode="max",
            filename="{epoch}-{step}-{val_avg_game_normalized_rewards:.2f}",
        ),
    ]
    if isinstance(trainer_config["logger"], WandbLogger):
        trainer_config["callbacks"].append(WandbSaveCallback())
    trainer = pl.Trainer(**trainer_config)

    # instantiate the lightning module
    if not cfg.eval.test_only:
        lm_model_config = OmegaConf.to_container(cfg.model, resolve=True)
        assert isinstance(lm_model_config, dict)
        if cfg.model.pretrained_graph_updater is not None:
            graph_updater_obs_gen = GraphUpdaterObsGen.load_from_checkpoint(
                to_absolute_path(cfg.model.pretrained_graph_updater.ckpt_path),
                word_vocab_path=cfg.model.pretrained_graph_updater.word_vocab_path,
                node_vocab_path=cfg.model.pretrained_graph_updater.node_vocab_path,
                relation_vocab_path=(
                    cfg.model.pretrained_graph_updater.relation_vocab_path
                ),
            )
            lm_model_config["pretrained_graph_updater"] = (
                graph_updater_obs_gen.graph_updater
            )
        lm = GATADoubleDQN(**lm_model_config, **cfg.train, **cfg.data)

        # fit
        trainer.fit(lm)

        # test
        trainer.test()
    else:
        assert (
            cfg.eval.checkpoint_path is not None
        ), "missing checkpoint path for testing"
        parsed = urlparse(cfg.eval.checkpoint_path)
        if parsed.scheme == "":
            # local path
            ckpt_path = to_absolute_path(cfg.eval.checkpoint_path)
        else:
            # remote path
            ckpt_path = cfg.eval.checkpoint_path
        model = GATADoubleDQN.load_from_checkpoint(
            ckpt_path,
            base_data_dir=cfg.data.base_data_dir,
            word_vocab_path=cfg.model.word_vocab_path,
            node_vocab_path=cfg.model.node_vocab_path,
            relation_vocab_path=cfg.model.relation_vocab_path,
        )
        trainer.test(model=model)
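# main() takes a DictConfig and relies on hydra.utils helpers (instantiate,
# to_absolute_path), so it is presumably registered as a Hydra entry point. A
# hypothetical sketch of that wiring; config_path and config_name are
# placeholder values, not necessarily the project's actual ones:
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="train_gata")
def hydra_entry_point(cfg: DictConfig) -> None:
    main(cfg)


if __name__ == "__main__":
    hydra_entry_point()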