def test_action_selector_choose_max_q(action_scores, action_mask, max_q_actions):
    """Check that ``select_max_q(action_scores, action_mask)`` equals ``max_q_actions``.

    The selector is built from small fixed hyperparameters. Only ``select_max_q``
    is called on it.
    """
    vocab_size = 10
    node_count = 5
    relation_count = 10

    # Name tensors are drawn in the same order the original single constructor
    # call evaluated its arguments, so the RNG is consumed identically.
    node_word_ids = torch.randint(vocab_size, (node_count, 3))
    node_mask = increasing_mask(node_count, 3)
    relation_word_ids = torch.randint(vocab_size, (relation_count, 3))
    relation_mask = increasing_mask(relation_count, 3)

    selector = ActionSelector(
        12, vocab_size, 16, node_count, 12, relation_count, 12,
        1, 1, 3, 1, 1, 1, 1,
        node_word_ids, node_mask, relation_word_ids, relation_mask,
    )

    chosen = selector.select_max_q(action_scores, action_mask)
    assert chosen.equal(max_q_actions)
def test_gata_double_dqn_forward(
    batch_size,
    obs_len,
    prev_action_len,
    num_action_cands,
    action_cand_len,
):
    """Check output shapes of a default-constructed ``GATADoubleDQN`` forward pass.

    The model receives random word ids, increasing masks and a random hidden
    state. The test then checks the sizes of the ``action_scores``,
    ``rnn_curr_hidden`` and ``current_graph`` entries of the returned mapping.
    """
    model = GATADoubleDQN()
    vocab = model.num_words
    hidden = model.hparams.hidden_dim

    # Inputs are created in the same order as the original inline arguments.
    obs_word_ids = torch.randint(vocab, (batch_size, obs_len))
    obs_mask = increasing_mask(batch_size, obs_len)
    prev_action_word_ids = torch.randint(vocab, (batch_size, prev_action_len))
    prev_action_mask = increasing_mask(batch_size, prev_action_len)
    prev_hidden = torch.rand(batch_size, hidden)
    cand_word_ids = torch.randint(
        vocab, (batch_size, num_action_cands, action_cand_len)
    )
    # A flat (batch * cands, len) mask folded back into three dimensions.
    flat_cand_mask = increasing_mask(batch_size * num_action_cands, action_cand_len)
    cand_mask = flat_cand_mask.view(batch_size, num_action_cands, action_cand_len)
    action_mask = increasing_mask(batch_size, num_action_cands)

    outputs = model(
        obs_word_ids,
        obs_mask,
        prev_action_word_ids,
        prev_action_mask,
        prev_hidden,
        cand_word_ids,
        cand_mask,
        action_mask,
    )

    graph_shape = (batch_size, model.num_relations, model.num_nodes, model.num_nodes)
    assert outputs["action_scores"].size() == (batch_size, num_action_cands)
    assert outputs["rnn_curr_hidden"].size() == (batch_size, hidden)
    assert outputs["current_graph"].size() == graph_shape
def test_action_selector(
    hidden_dim,
    num_words,
    word_emb_dim,
    num_nodes,
    node_emb_dim,
    num_relations,
    relation_emb_dim,
    text_encoder_num_blocks,
    text_encoder_num_conv_layers,
    text_encoder_kernel_size,
    text_encoder_num_heads,
    graph_encoder_num_cov_layers,
    graph_encoder_num_bases,
    action_scorer_num_heads,
    batch_size,
    obs_len,
    num_action_cands,
    action_cand_len,
):
    """Check that an ``ActionSelector`` forward pass yields a (batch_size, num_action_cands) tensor.

    The inputs are random word ids, a random graph tensor and increasing masks.
    """
    # Name tensors come first; they were the trailing constructor arguments and
    # so were evaluated before the constructor ran.
    node_name_word_ids = torch.randint(num_words, (num_nodes, 3))
    node_name_mask = increasing_mask(num_nodes, 3)
    rel_name_word_ids = torch.randint(num_words, (num_relations, 3))
    rel_name_mask = increasing_mask(num_relations, 3)

    selector = ActionSelector(
        hidden_dim,
        num_words,
        word_emb_dim,
        num_nodes,
        node_emb_dim,
        num_relations,
        relation_emb_dim,
        text_encoder_num_blocks,
        text_encoder_num_conv_layers,
        text_encoder_kernel_size,
        text_encoder_num_heads,
        graph_encoder_num_cov_layers,
        graph_encoder_num_bases,
        action_scorer_num_heads,
        node_name_word_ids,
        node_name_mask,
        rel_name_word_ids,
        rel_name_mask,
    )

    obs_word_ids = torch.randint(num_words, (batch_size, obs_len))
    obs_mask = increasing_mask(batch_size, obs_len)
    graph = torch.rand(batch_size, num_relations, num_nodes, num_nodes)
    cand_word_ids = torch.randint(
        num_words, (batch_size, num_action_cands, action_cand_len)
    )
    flat_cand_mask = increasing_mask(batch_size * num_action_cands, action_cand_len)
    cand_mask = flat_cand_mask.view(batch_size, num_action_cands, action_cand_len)
    action_mask = increasing_mask(batch_size, num_action_cands)

    scores = selector(
        obs_word_ids, obs_mask, graph, cand_word_ids, cand_mask, action_mask
    )
    assert scores.size() == (batch_size, num_action_cands)
def __init__(self):
    """Create the embeddings, encoders and random name tensors on ``self``.

    The sizes come from free variables in an enclosing scope: ``num_words``,
    ``hidden_dim``, ``node_emb_dim``, ``rel_emb_dim``, ``num_node`` and
    ``num_relations``. Assignment order is kept exactly, because the
    randomly-initialised modules and ``torch.randint`` draws consume the RNG
    in sequence.
    """
    # The parent initializer must run before any nn.Embedding submodule is assigned.
    super().__init__()

    node_feature_dim = hidden_dim + node_emb_dim
    relation_feature_dim = hidden_dim + rel_emb_dim
    node_name_len = 3
    rel_name_len = 2

    self.word_embeddings = nn.Embedding(num_words, hidden_dim)
    self.text_encoder = TextEncoder(1, 1, 1, hidden_dim, 1)
    self.graph_encoder = GraphEncoder(
        node_feature_dim,
        relation_feature_dim,
        num_relations,
        [hidden_dim],
        1,
    )
    self.node_embeddings = nn.Embedding(num_node, node_emb_dim)
    self.relation_embeddings = nn.Embedding(num_relations, rel_emb_dim)

    # Random word ids for each node / relation name, each paired with an increasing mask.
    self.node_name_word_ids = torch.randint(num_words, (num_node, node_name_len))
    self.node_name_mask = increasing_mask(num_node, node_name_len)
    self.rel_name_word_ids = torch.randint(num_words, (num_relations, rel_name_len))
    self.rel_name_mask = increasing_mask(num_relations, rel_name_len)
def test_action_scorer(
    hidden_dim,
    num_heads,
    batch_size,
    num_action_cands,
    action_cand_len,
    num_node,
    obs_len,
):
    """Check that ``ActionScorer`` returns a (batch_size, num_action_cands) tensor."""
    scorer = ActionScorer(hidden_dim, num_heads)

    cand_encodings = torch.rand(
        batch_size, num_action_cands, action_cand_len, hidden_dim
    )
    # One (num_action_cands, action_cand_len) mask broadcast over the batch.
    # start_with_zero=True makes its first row all zeros.
    shared_cand_mask = increasing_mask(
        num_action_cands, action_cand_len, start_with_zero=True
    )
    cand_mask = shared_cand_mask.unsqueeze(0).expand(batch_size, -1, -1)
    # Random 0/1 integer action mask.
    action_mask = torch.randint(2, (batch_size, num_action_cands))
    obs_features = torch.rand(batch_size, obs_len, hidden_dim)
    node_features = torch.rand(batch_size, num_node, hidden_dim)
    obs_mask = increasing_mask(batch_size, obs_len)

    out = scorer(
        cand_encodings, cand_mask, action_mask, obs_features, node_features, obs_mask
    )
    assert out.size() == (batch_size, num_action_cands)
def test_text_enc_block(num_conv_layers, kernel_size, hidden_dim, num_heads, batch_size, seq_len):
    """Check that ``TextEncoderBlock`` output has the same (batch, seq, hidden) shape as its input."""
    encoder_block = TextEncoderBlock(num_conv_layers, kernel_size, hidden_dim, num_heads)
    shape = (batch_size, seq_len, hidden_dim)

    # Random input features plus an increasing mask.
    features = torch.rand(*shape)
    mask = increasing_mask(batch_size, seq_len)

    assert encoder_block(features, mask).size() == shape
def test_encoder_mixin(
    num_words,
    hidden_dim,
    node_emb_dim,
    rel_emb_dim,
    num_node,
    num_relations,
    batch_size,
    seq_len,
):
    """Check the output shapes of ``EncoderMixin``'s encoding helpers.

    A throwaway ``nn.Module`` subclass mixes in ``EncoderMixin`` and sets the
    embeddings, encoders and name tensors used here. The test then checks
    ``encode_text``, ``get_node_features``, ``get_relation_features`` and
    ``encode_graph``.
    """

    class StubEncoder(EncoderMixin, nn.Module):
        def __init__(self):
            # The parent initializer must run before any submodule is assigned.
            super().__init__()

            node_feature_dim = hidden_dim + node_emb_dim
            relation_feature_dim = hidden_dim + rel_emb_dim

            self.word_embeddings = nn.Embedding(num_words, hidden_dim)
            self.text_encoder = TextEncoder(1, 1, 1, hidden_dim, 1)
            self.graph_encoder = GraphEncoder(
                node_feature_dim,
                relation_feature_dim,
                num_relations,
                [hidden_dim],
                1,
            )
            self.node_embeddings = nn.Embedding(num_node, node_emb_dim)
            self.relation_embeddings = nn.Embedding(num_relations, rel_emb_dim)
            self.node_name_word_ids = torch.randint(num_words, (num_node, 3))
            self.node_name_mask = increasing_mask(num_node, 3)
            self.rel_name_word_ids = torch.randint(num_words, (num_relations, 2))
            self.rel_name_mask = increasing_mask(num_relations, 2)

    encoder = StubEncoder()

    text_ids = torch.randint(num_words, (batch_size, seq_len))
    text_mask = increasing_mask(batch_size, seq_len)
    encoded_text = encoder.encode_text(text_ids, text_mask)
    assert encoded_text.size() == (batch_size, seq_len, hidden_dim)

    node_feats = encoder.get_node_features()
    assert node_feats.size() == (num_node, hidden_dim + node_emb_dim)

    rel_feats = encoder.get_relation_features()
    assert rel_feats.size() == (num_relations, hidden_dim + rel_emb_dim)

    graph = torch.rand(batch_size, num_relations, num_node, num_node)
    assert encoder.encode_graph(graph).size() == (batch_size, num_node, hidden_dim)
def test_increasing_mask():
    """Pin ``increasing_mask(3, 2)`` with and without ``start_with_zero=True``.

    Without the flag the first row holds one 1. With it the first row is all zeros.
    """
    cases = [
        ({}, [[1, 0], [1, 1], [1, 1]]),
        ({"start_with_zero": True}, [[0, 0], [1, 0], [1, 1]]),
    ]
    for extra_kwargs, expected_rows in cases:
        mask = increasing_mask(3, 2, **extra_kwargs)
        assert mask.equal(torch.tensor(expected_rows).float())