예제 #1
0
 def test_empty(self):
     """LEFT and RIGHT applied to a freshly created stack each give one zero row."""
     zero_tape = [[[0., 0., 0., 0.]]]
     # Both actions are checked against their own new stack so that the
     # two scenarios cannot influence each other.
     for reduce_action in (LEFT, RIGHT):
         fresh = TransitionParserStack.empty(1, 4)
         fresh.update(reduce_action, VEC1)
         assert fresh.tapes.tolist() == zero_tape
예제 #2
0
 def test_shift_shift_right(self):
     """Check the tapes after each of SHIFT, SHIFT, RIGHT applied to one stack."""
     parser_stack = TransitionParserStack.empty(1, 4)
     # Each entry: (action, vector handed to update, tapes expected afterwards).
     steps = [
         (SHIFT, VEC1, [[[1., 0., 0., 0.]]]),
         (SHIFT, VEC2, [[[0., 1., 0., 0.], [1., 0., 0., 0.]]]),
         (RIGHT, VEC1, [[[1., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]]),
     ]
     for action, vector, expected_tapes in steps:
         parser_stack.update(action, vector)
         assert parser_stack.tapes.tolist() == expected_tapes
예제 #3
0
 def test_superpos_empty(self):
     """Apply a uniform policy over the three actions after two SHIFTs.

     The expected tapes below are the values the original test asserts; they
     are compared approximately because they contain thirds.
     """
     parser_stack = TransitionParserStack.empty(1, 4, None)
     parser_stack.update(SHIFT, VEC1)
     parser_stack.update(SHIFT, VEC2)

     # One row of three equal weights, i.e. 1/3 for every action.
     uniform_policy = torch.ones(1, 3) / 3
     parser_stack.update(uniform_policy, VEC1)

     top_row = [2 / 3, 1 / 3, 0, 0]
     middle_row = [0, 1 / 3, 0, 0]
     bottom_row = [1 / 3, 0, 0, 0]
     expected = [[top_row, middle_row, bottom_row]]
     torch.testing.assert_allclose(parser_stack.tapes.tolist(), expected)
예제 #4
0
    def forward(self, tokens, actions):
        """Run the parser for one batch and collect the policy of every step.

        For each of ``actions.size()[1]`` steps the controller is fed the current
        buffer head and the stack summary, a softmax policy is computed from its
        output, and that policy drives the buffer-position shift and the stack
        update before the stack is summarized again.

        Args:
            tokens: Input handed to ``self.embedder`` and ``self._get_buffer_mask``.
            actions: Gold actions; ``actions.size()`` must unpack into
                (batch size, number of steps).

        Returns:
            A dict holding ``"policies"`` (the per-step softmax outputs stacked
            along dim 1, before ``self.aligner`` is applied) and, when the actions
            returned by ``self._get_actions_and_mask`` are not ``None``, ``"loss"``.

        NOTE(review): ``actions.size()`` is read before the ``is None`` check, so
        the check only concerns the value returned by ``_get_actions_and_mask``.
        """
        batch_size, num_steps = actions.size()
        actions, act_mask = self._get_actions_and_mask(actions)
        buf_mask = self._get_buffer_mask(tokens)

        embedded = self.embedder(tokens)
        device = embedded.device
        stack = TransitionParserStack.empty(batch_size, self.stack_dim, device=device)
        summary = torch.zeros(batch_size, self.summary_dim, device=device)
        self.controller.reset(batch_size, device=device)

        buf_pos = self._init_buffer_pos(buf_mask)
        step_policies = []

        # Original note: num_actions = 2 * num_tokens - 1.
        for _ in range(num_steps):
            # Read the buffer head and ask the controller for a policy.
            head = (buf_pos @ embedded).squeeze(dim=1)
            hidden = self.controller(head, summary)
            step_policy = self.policy(hidden).softmax(dim=-1)
            step_policies.append(step_policy)

            # Advance the buffer first, then push the *pre-shift* head to the
            # stack, and finally refresh the summary from the updated stack.
            stack_policy = self._get_stack_policy(step_policy)
            buf_pos = self._shift_buffer_pos(buf_pos, stack_policy, buf_mask)
            stack.update(stack_policy, head)
            summary = self._summarize(stack)

        # The stack-space policies are reported as-is; the aligned version
        # (mapped to label space) is what the loss and accuracy consume.
        stacked = torch.stack(step_policies, dim=1)
        output_dict = {"policies": stacked}
        aligned = self.aligner(stacked)

        if actions is None:
            return output_dict

        output_dict["loss"] = sequence_cross_entropy_with_logits(aligned, actions, act_mask)
        self.accuracy(aligned, actions, act_mask)
        return output_dict
예제 #5
0
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 controller: StackController,
                 summary_size: int,
                 hard: bool = False):
        """Wire up the embedder, controller, policy head and label aligner.

        Args:
            vocab: Forwarded to the parent constructor.
            text_field_embedder: Stored as ``self.embedder``; its output
                dimension becomes ``self.stack_dim``.
            controller: Stored as ``self.controller``; supplies the summary
                dimension and the input width of the policy layer.
            summary_size: Stored as ``self.summary_size``; multiplied by the
                stack dimension it must equal the controller's summary dimension.
            hard: Stored as ``self.hard``; not otherwise used in this constructor.

        Raises:
            AssertionError: If ``stack_dim * summary_size`` differs from the
                controller's summary dimension.
        """
        super().__init__(vocab)

        self.embedder = text_field_embedder
        self.controller = controller
        self.accuracy = CategoricalAccuracy()

        stack_dim = self.embedder.get_output_dim()
        summary_dim = controller.get_summary_dim()
        self.stack_dim = stack_dim
        self.summary_size = summary_size
        self.summary_dim = summary_dim
        assert stack_dim * summary_size == summary_dim

        self.hard = hard

        controller_dim = controller.get_output_dim()
        action_count = TransitionParserStack.get_num_actions()
        # Maps the controller output to a score per stack action.
        self.policy = torch.nn.Linear(controller_dim, action_count)
        # Square linear map used to align stack action space to label action space.
        self.aligner = torch.nn.Linear(action_count, action_count)
예제 #6
0
 def test_returns(self):
     """update() must return the very object exposed as the stack's ``tapes``."""
     parser_stack = TransitionParserStack.empty(1, 4, None)
     returned = parser_stack.update(SHIFT, VEC1)
     # Identity, not equality: the same object has to come back.
     assert returned is parser_stack.tapes
예제 #7
0
 def test_get_num_actions(self):
     """The stack reports exactly three actions."""
     action_count = TransitionParserStack.get_num_actions()
     assert action_count == 3
예제 #8
0
 def test_initialize(self):
     """A newly created stack for one batch element has a single, empty tape."""
     fresh = TransitionParserStack.empty(1, 4)
     initial_tapes = fresh.tapes.tolist()
     assert initial_tapes == [[]]