Example 1
import torch

# GraphUpdaterObsGen comes from the project under test; the size and mode
# arguments are presumably supplied via pytest parametrization.
def test_graph_updater_obs_gen_forward(batch_size, obs_len, prev_action_len,
                                       rnn_prev_hidden, training):
    g = GraphUpdaterObsGen()
    g.train(training)
    # One step of episode data: random word ids and binary masks.
    episode_data = {
        "obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
        "obs_mask": torch.randint(2, (batch_size, obs_len)).float(),
        "prev_action_word_ids": torch.randint(g.num_words,
                                              (batch_size, prev_action_len)),
        "prev_action_mask": torch.randint(2, (batch_size, prev_action_len)).float(),
        "groundtruth_obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
    }
    results = g(
        episode_data,
        rnn_prev_hidden=torch.rand(batch_size, g.hparams.hidden_dim)
        if rnn_prev_hidden else None,
    )
    # The forward pass always returns the new hidden state and a per-example loss.
    assert results["h_t"].size() == (batch_size, g.hparams.hidden_dim)
    assert results["batch_loss"].size() == (batch_size,)
    # Prediction outputs are only checked in eval mode.
    if not training:
        assert results["pred_obs_word_ids"].size() == (batch_size, obs_len)
        # decoded_obs_word_ids has variable length in the second dimension
        assert results["decoded_obs_word_ids"].size(0) == batch_size
        assert results["decoded_obs_word_ids"].ndim == 2
Example 2
import torch

# As above, GraphUpdaterObsGen comes from the project under test and the test
# arguments are presumably supplied via pytest parametrization.
def test_graph_updater_obs_gen_process_batch(batch_size, obs_len,
                                             prev_action_len, max_episode_len,
                                             hidden, training):
    g = GraphUpdaterObsGen()
    g.train(training)
    # A batch is a list of per-step dicts, one entry per step of the episode.
    batch = [{
        "obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
        "obs_mask": torch.randint(2, (batch_size, obs_len)).float(),
        "prev_action_word_ids": torch.randint(g.num_words,
                                              (batch_size, prev_action_len)),
        "prev_action_mask": torch.randint(2, (batch_size, prev_action_len)).float(),
        "groundtruth_obs_word_ids": torch.randint(g.num_words, (batch_size, obs_len)),
        "step_mask": torch.randint(2, (batch_size,)).float(),
    } for _ in range(max_episode_len)]
    h_t = torch.rand(batch_size, g.hparams.hidden_dim) if hidden else None
    results = g.process_batch(batch, h_t=h_t)
    # One scalar loss and one hidden state per episode step.
    assert len(results["losses"]) == max_episode_len
    assert all(loss.ndim == 0 for loss in results["losses"])
    assert len(results["hiddens"]) == max_episode_len
    assert all(h.size() == (batch_size, g.hparams.hidden_dim)
               for h in results["hiddens"])
    # Predictions, decoded sequences and F1 scores are only checked in eval mode.
    if not training:
        assert len(results["preds"]) == max_episode_len
        assert all(pred.size() == (batch_size, obs_len)
                   for pred in results["preds"])
        # Decoded sequences have variable length in the second dimension.
        assert len(results["decoded"]) == max_episode_len
        assert all(dec.ndim == 2 for dec in results["decoded"])
        assert all(dec.size(0) == batch_size for dec in results["decoded"])
        # At most max_episode_len * batch_size F1 scores, each a scalar.
        assert len(results["f1s"]) <= max_episode_len * batch_size
        assert all(f1.ndim == 0 for f1 in results["f1s"])
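
The shapes asserted above also suggest how per-episode results can be reduced to scalars. A minimal sketch assuming the same GraphUpdaterObsGen interface exercised by the test; the aggregation itself is illustrative and not taken from the original code:

import torch

g = GraphUpdaterObsGen()
g.train(False)  # eval mode, so "preds", "decoded" and "f1s" are populated
batch = ...     # list of per-step dicts built as in the test above
results = g.process_batch(batch, h_t=None)

# Each entry in "losses" and "f1s" is a scalar tensor, so they stack cleanly.
episode_loss = torch.stack(results["losses"]).mean()
mean_f1 = torch.stack(results["f1s"]).mean() if results["f1s"] else None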