Example #1
def test_action_selector_choose_max_q(action_scores, action_mask,
                                      max_q_actions):
    num_words = 10
    num_nodes = 5
    num_relations = 10
    action_selector = ActionSelector(
        12,  # hidden_dim
        num_words,
        16,  # word_emb_dim
        num_nodes,
        12,  # node_emb_dim
        num_relations,
        12,  # relation_emb_dim
        1,  # text_encoder_num_blocks
        1,  # text_encoder_num_conv_layers
        3,  # text_encoder_kernel_size
        1,  # text_encoder_num_heads
        1,  # graph_encoder_num_cov_layers
        1,  # graph_encoder_num_bases
        1,  # action_scorer_num_heads
        torch.randint(num_words, (num_nodes, 3)),  # node_name_word_ids
        increasing_mask(num_nodes, 3),  # node_name_mask
        torch.randint(num_words, (num_relations, 3)),  # rel_name_word_ids
        increasing_mask(num_relations, 3),  # rel_name_mask
    )
    assert action_selector.select_max_q(action_scores,
                                        action_mask).equal(max_q_actions)
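
The `action_scores`, `action_mask`, and `max_q_actions` arguments are evidently injected by a pytest parametrization that is not shown here. A hypothetical sketch of one follows; it assumes `select_max_q` returns, for each batch element, the index of the highest-scoring unmasked candidate (the names and values are illustrative, not taken from the actual suite):

import pytest
import torch

@pytest.mark.parametrize(
    "action_scores,action_mask,max_q_actions",
    [
        # all three candidates available: plain argmax picks index 1
        (
            torch.tensor([[0.1, 0.9, 0.5]]),
            torch.tensor([[1.0, 1.0, 1.0]]),
            torch.tensor([1]),
        ),
        # the top-scoring candidate is masked out, so index 2 wins
        (
            torch.tensor([[0.1, 0.9, 0.5]]),
            torch.tensor([[1.0, 0.0, 1.0]]),
            torch.tensor([2]),
        ),
    ],
)
def test_action_selector_choose_max_q(action_scores, action_mask,
                                      max_q_actions):
    ...  # body as in the example above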
Example #2
def test_gata_double_dqn_forward(
    batch_size,
    obs_len,
    prev_action_len,
    num_action_cands,
    action_cand_len,
):
    gata_ddqn = GATADoubleDQN()
    results = gata_ddqn(
        torch.randint(gata_ddqn.num_words, (batch_size, obs_len)),
        increasing_mask(batch_size, obs_len),
        torch.randint(gata_ddqn.num_words, (batch_size, prev_action_len)),
        increasing_mask(batch_size, prev_action_len),
        torch.rand(batch_size, gata_ddqn.hparams.hidden_dim),
        torch.randint(gata_ddqn.num_words,
                      (batch_size, num_action_cands, action_cand_len)),
        increasing_mask(batch_size * num_action_cands,
                        action_cand_len).view(batch_size, num_action_cands,
                                              action_cand_len),
        increasing_mask(batch_size, num_action_cands),
    )
    assert results["action_scores"].size() == (batch_size, num_action_cands)
    assert results["rnn_curr_hidden"].size() == (
        batch_size,
        gata_ddqn.hparams.hidden_dim,
    )
    assert results["current_graph"].size() == (
        batch_size,
        gata_ddqn.num_relations,
        gata_ddqn.num_nodes,
        gata_ddqn.num_nodes,
    )
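
For reference, the forward contract exercised above: observation token ids and mask, previous-action token ids and mask, the previous RNN hidden state, and padded action candidates with their masks go in; per-candidate Q scores, the next hidden state, and the predicted relation graph come out. A smoke run with small, arbitrary sizes (the values are assumptions, not fixture values from the suite):

test_gata_double_dqn_forward(
    batch_size=2,
    obs_len=7,
    prev_action_len=4,
    num_action_cands=3,
    action_cand_len=5,
)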
Example #3
def test_action_selector(
    hidden_dim,
    num_words,
    word_emb_dim,
    num_nodes,
    node_emb_dim,
    num_relations,
    relation_emb_dim,
    text_encoder_num_blocks,
    text_encoder_num_conv_layers,
    text_encoder_kernel_size,
    text_encoder_num_heads,
    graph_encoder_num_cov_layers,
    graph_encoder_num_bases,
    action_scorer_num_heads,
    batch_size,
    obs_len,
    num_action_cands,
    action_cand_len,
):
    action_selector = ActionSelector(
        hidden_dim,
        num_words,
        word_emb_dim,
        num_nodes,
        node_emb_dim,
        num_relations,
        relation_emb_dim,
        text_encoder_num_blocks,
        text_encoder_num_conv_layers,
        text_encoder_kernel_size,
        text_encoder_num_heads,
        graph_encoder_num_cov_layers,
        graph_encoder_num_bases,
        action_scorer_num_heads,
        torch.randint(num_words, (num_nodes, 3)),
        increasing_mask(num_nodes, 3),
        torch.randint(num_words, (num_relations, 3)),
        increasing_mask(num_relations, 3),
    )
    assert (action_selector(
        torch.randint(num_words, (batch_size, obs_len)),
        increasing_mask(batch_size, obs_len),
        torch.rand(batch_size, num_relations, num_nodes, num_nodes),
        torch.randint(num_words,
                      (batch_size, num_action_cands, action_cand_len)),
        increasing_mask(batch_size * num_action_cands,
                        action_cand_len).view(batch_size, num_action_cands,
                                              action_cand_len),
        increasing_mask(batch_size, num_action_cands),
    ).size() == (batch_size, num_action_cands))
Example #4
def test_action_scorer(
    hidden_dim,
    num_heads,
    batch_size,
    num_action_cands,
    action_cand_len,
    num_node,
    obs_len,
):
    action_scorer = ActionScorer(hidden_dim, num_heads)

    assert (action_scorer(
        torch.rand(batch_size, num_action_cands, action_cand_len, hidden_dim),
        increasing_mask(num_action_cands,
                        action_cand_len,
                        start_with_zero=True).unsqueeze(0).expand(
                            batch_size, -1, -1),
        torch.randint(2, (batch_size, num_action_cands)),
        torch.rand(batch_size, obs_len, hidden_dim),
        torch.rand(batch_size, num_node, hidden_dim),
        increasing_mask(batch_size, obs_len),
    ).size() == (batch_size, num_action_cands))
Example #5
def test_text_enc_block(num_conv_layers, kernel_size, hidden_dim, num_heads,
                        batch_size, seq_len):
    text_enc_block = TextEncoderBlock(num_conv_layers, kernel_size, hidden_dim,
                                      num_heads)
    # random tensors and increasing masks
    assert text_enc_block(
        torch.rand(batch_size, seq_len, hidden_dim),
        increasing_mask(batch_size, seq_len),
    ).size() == (
        batch_size,
        seq_len,
        hidden_dim,
    )
Example #6
def test_encoder_mixin(
    num_words,
    hidden_dim,
    node_emb_dim,
    rel_emb_dim,
    num_node,
    num_relations,
    batch_size,
    seq_len,
):
    class TestEncoder(EncoderMixin, nn.Module):
        def __init__(self):
            super().__init__()
            self.word_embeddings = nn.Embedding(num_words, hidden_dim)
            self.text_encoder = TextEncoder(1, 1, 1, hidden_dim, 1)
            self.graph_encoder = GraphEncoder(
                hidden_dim + node_emb_dim,
                hidden_dim + rel_emb_dim,
                num_relations,
                [hidden_dim],
                1,
            )
            self.node_embeddings = nn.Embedding(num_node, node_emb_dim)
            self.relation_embeddings = nn.Embedding(num_relations, rel_emb_dim)

            self.node_name_word_ids = torch.randint(num_words, (num_node, 3))
            self.node_name_mask = increasing_mask(num_node, 3)
            self.rel_name_word_ids = torch.randint(num_words,
                                                   (num_relations, 2))
            self.rel_name_mask = increasing_mask(num_relations, 2)

    te = TestEncoder()
    assert (te.encode_text(
        torch.randint(num_words, (batch_size, seq_len)),
        increasing_mask(batch_size, seq_len),
    ).size() == (batch_size, seq_len, hidden_dim))
    assert te.get_node_features().size() == (num_node,
                                             hidden_dim + node_emb_dim)
    assert te.get_relation_features().size() == (
        num_relations,
        hidden_dim + rel_emb_dim,
    )
    assert te.encode_graph(
        torch.rand(batch_size, num_relations, num_node,
                   num_node)).size() == (batch_size, num_node, hidden_dim)
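
`EncoderMixin` itself is not shown. A minimal sketch of `encode_text` that would satisfy the shape assertion above, assuming the mixin simply embeds the word ids and contextualizes them with the text encoder (a guess at the contract, not the repository's implementation):

class EncoderMixinSketch:
    def encode_text(self, word_ids, mask):
        # (batch, seq_len) ids -> (batch, seq_len, hidden_dim),
        # contextualized by the text encoder under the padding mask
        return self.text_encoder(self.word_embeddings(word_ids), mask)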
Example #7
def test_increasing_mask():
    assert increasing_mask(3, 2).equal(
        torch.tensor([[1, 0], [1, 1], [1, 1]]).float())
    assert increasing_mask(3, 2, start_with_zero=True).equal(
        torch.tensor([[0, 0], [1, 0], [1, 1]]).float())
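
Taken together, these assertions pin `increasing_mask` down: row i of an (m, n) float mask holds min(i + 1, n) leading ones, or min(i, n) of them with start_with_zero=True. A minimal implementation that satisfies both assertions (a sketch, not necessarily the suite's helper):

import torch

def increasing_mask(num_rows: int, num_cols: int,
                    start_with_zero: bool = False) -> torch.Tensor:
    """Float mask whose i-th row has an increasing count of leading ones."""
    offset = 0 if start_with_zero else 1
    ones_per_row = torch.arange(num_rows) + offset
    return (torch.arange(num_cols).unsqueeze(0) <
            ones_per_row.unsqueeze(1)).float()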