Exemplo n.º 1
0
def test_all_attention_works_the_same(attention_type: str):
    """Check that the attention module named ``attention_type`` returns a (1, 2) output.

    The class is looked up with ``Attention.by_name``, instantiated, and applied to a
    single vector of size 3 and a single 2 x 3 matrix.
    """
    attention_cls = Attention.by_name(attention_type)

    query = torch.FloatTensor([[-7, -8, -9]])
    keys = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])

    # These classes are constructed from the last-dimension sizes of the two inputs;
    # every other class is constructed without arguments.
    takes_dims = attention_cls in {BilinearAttention, AdditiveAttention, LinearAttention}
    attention = (
        attention_cls(query.size(-1), keys.size(-1)) if takes_dims else attention_cls()
    )

    weights = attention(query, keys)
    assert tuple(weights.size()) == (1, 2)
Exemplo n.º 2
0
    def setUp(self):
        """Create the transition function and the grouped decoder state used by the tests.

        Sets ``self.decoder_step``, ``self.encoder_outputs``, ``self.encoder_output_mask``,
        ``self.possible_actions`` and ``self.state``.  The state holds three group
        elements whose batch indices are ``[0, 1, 0]``.
        """
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name("dot_product")(),
            add_action_bias=False,
        )

        batch_indices = [0, 1, 0]
        group_size = len(batch_indices)
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([value]) for value in (0.1, 1.1, 2.2)]

        def make_rows():
            # Row k is [k, k].  A new tensor is built on every call so the four
            # state tensors below never share storage.
            return torch.FloatTensor([[k, k] for k in range(group_size)])

        hidden_state = make_rows()
        memory_cell = make_rows()
        previous_action_embedding = make_rows()
        attended_question = make_rows()

        # Valid actions are keyed first by non-terminal and then by action _type_:
        # "global" actions come from the global grammar, "linked" actions are
        # instance-specific and are generated based on question attention.  Every entry
        # is a tuple of (input representation, output representation, action ids).
        e_actions = {
            "global": (
                torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                [0, 1, 2],
            ),
            "linked": (
                torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
                torch.FloatTensor([[3, 3], [4, 4]]),
                [3, 4],
            ),
        }
        d_actions = {
            "global": (
                torch.FloatTensor([[0, 0]]),
                torch.FloatTensor([[-1, -1]]),
                [0],
            ),
            "linked": (
                torch.FloatTensor(
                    [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]]
                ),
                torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                [1, 2, 3],
            ),
        }
        valid_actions = {"e": e_actions, "d": d_actions}

        # One grammar statelet per group element, in the same order as batch_indices.
        grammar_state = [
            GrammarStatelet([symbol], valid_actions, is_nonterminal)
            for symbol in ("e", "d", "e")
        ]

        self.encoder_outputs = torch.FloatTensor(
            [
                [[1, 2], [3, 4], [5, 6]],
                [[10, 11], [12, 13], [14, 15]],
            ]
        )
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

        e_productions = [
            ("e -> f", False, None),
            ("e -> g", True, None),
            ("e -> h", True, None),
            ("e -> i", True, None),
            ("e -> j", True, None),
        ]
        d_productions = [
            ("d -> q", True, None),
            ("d -> g", True, None),
            ("d -> h", True, None),
            ("d -> i", True, None),
        ]
        self.possible_actions = [e_productions, d_productions]

        # Every statelet receives its own row of the per-element tensors plus the full
        # encoder outputs and mask.
        rnn_state = [
            RnnStatelet(
                hidden_state[k],
                memory_cell[k],
                previous_action_embedding[k],
                attended_question[k],
                self.encoder_outputs,
                self.encoder_output_mask,
            )
            for k in range(group_size)
        ]
        self.state = GrammarBasedState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions,
        )
    def setUp(self):
        """Create the transition function and the grouped decoder state used by the tests.

        Sets ``self.decoder_step`` (built with ``num_start_types=3``),
        ``self.encoder_outputs``, ``self.encoder_output_mask``, ``self.possible_actions``
        and ``self.state``.  The state holds three group elements whose batch indices
        are ``[0, 1, 0]``.
        """
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name('dot_product')(),
            num_start_types=3,
            add_action_bias=False,
        )

        batch_indices = [0, 1, 0]
        group_size = len(batch_indices)
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([value]) for value in (.1, 1.1, 2.2)]

        def make_rows():
            # Row k is [k, k].  A new tensor is built on every call so the four
            # state tensors below never share storage.
            return torch.FloatTensor([[k, k] for k in range(group_size)])

        hidden_state = make_rows()
        memory_cell = make_rows()
        previous_action_embedding = make_rows()
        attended_question = make_rows()

        # Valid actions are keyed first by non-terminal and then by action _type_:
        # "global" actions come from the global grammar, "linked" actions are
        # instance-specific and are generated based on question attention.  Every entry
        # is a tuple of (input representation, output representation, action ids).
        e_actions = {
            'global': (
                torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                [0, 1, 2],
            ),
            'linked': (
                torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                torch.FloatTensor([[3, 3], [4, 4]]),
                [3, 4],
            ),
        }
        d_actions = {
            'global': (
                torch.FloatTensor([[0, 0]]),
                torch.FloatTensor([[-1, -1]]),
                [0],
            ),
            'linked': (
                torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9]]),
                torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                [1, 2, 3],
            ),
        }
        valid_actions = {'e': e_actions, 'd': d_actions}

        # One grammar statelet per group element, in the same order as batch_indices.
        grammar_state = [
            GrammarStatelet([symbol], valid_actions, is_nonterminal)
            for symbol in ('e', 'd', 'e')
        ]

        self.encoder_outputs = torch.FloatTensor(
            [
                [[1, 2], [3, 4], [5, 6]],
                [[10, 11], [12, 13], [14, 15]],
            ]
        )
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

        e_productions = [
            ('e -> f', False, None),
            ('e -> g', True, None),
            ('e -> h', True, None),
            ('e -> i', True, None),
            ('e -> j', True, None),
        ]
        d_productions = [
            ('d -> q', True, None),
            ('d -> g', True, None),
            ('d -> h', True, None),
            ('d -> i', True, None),
        ]
        self.possible_actions = [e_productions, d_productions]

        # Every statelet receives its own row of the per-element tensors plus the full
        # encoder outputs and mask.
        rnn_state = [
            RnnStatelet(
                hidden_state[k],
                memory_cell[k],
                previous_action_embedding[k],
                attended_question[k],
                self.encoder_outputs,
                self.encoder_output_mask,
            )
            for k in range(group_size)
        ]
        self.state = GrammarBasedState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions,
        )
    def setUp(self):
        """Create the transition function and the grouped decoder state used by the tests.

        Sets ``self.decoder_step`` (built with ``num_start_types=3``),
        ``self.encoder_outputs``, ``self.encoder_output_mask``, ``self.possible_actions``
        and ``self.state``.  The state is a ``GrammarBasedDecoderState`` holding three
        group elements whose batch indices are ``[0, 1, 0]``.
        """
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name('dot_product')(),
            num_start_types=3,
            add_action_bias=False,
        )

        batch_indices = [0, 1, 0]
        group_size = len(batch_indices)
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([value]) for value in (.1, 1.1, 2.2)]

        def make_rows():
            # Row k is [k, k].  A new tensor is built on every call so the four
            # state tensors below never share storage.
            return torch.FloatTensor([[k, k] for k in range(group_size)])

        hidden_state = make_rows()
        memory_cell = make_rows()
        previous_action_embedding = make_rows()
        attended_question = make_rows()

        # Valid actions are keyed first by non-terminal and then by action _type_:
        # "global" actions come from the global grammar, "linked" actions are
        # instance-specific and are generated based on question attention.  Every entry
        # is a tuple of (input representation, output representation, action ids).
        e_actions = {
            'global': (
                torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                [0, 1, 2],
            ),
            'linked': (
                torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                torch.FloatTensor([[3, 3], [4, 4]]),
                [3, 4],
            ),
        }
        d_actions = {
            'global': (
                torch.FloatTensor([[0, 0]]),
                torch.FloatTensor([[-1, -1]]),
                [0],
            ),
            'linked': (
                torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9]]),
                torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                [1, 2, 3],
            ),
        }
        valid_actions = {'e': e_actions, 'd': d_actions}

        # One grammar state per group element, in the same order as batch_indices.
        # The two empty dict arguments are created anew for every element.
        grammar_state = [
            GrammarState([symbol], {}, valid_actions, {}, is_nonterminal)
            for symbol in ('e', 'd', 'e')
        ]

        self.encoder_outputs = torch.FloatTensor(
            [
                [[1, 2], [3, 4], [5, 6]],
                [[10, 11], [12, 13], [14, 15]],
            ]
        )
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

        e_productions = [
            ('e -> f', False, None),
            ('e -> g', True, None),
            ('e -> h', True, None),
            ('e -> i', True, None),
            ('e -> j', True, None),
        ]
        d_productions = [
            ('d -> q', True, None),
            ('d -> g', True, None),
            ('d -> h', True, None),
            ('d -> i', True, None),
        ]
        self.possible_actions = [e_productions, d_productions]

        # Every RNN state receives its own row of the per-element tensors plus the full
        # encoder outputs and mask.
        rnn_state = [
            RnnState(
                hidden_state[k],
                memory_cell[k],
                previous_action_embedding[k],
                attended_question[k],
                self.encoder_outputs,
                self.encoder_output_mask,
            )
            for k in range(group_size)
        ]
        self.state = GrammarBasedDecoderState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions,
        )