def test_empty(self):
    # Reducing an empty stack is a no-op: the tape grows by one all-zero cell.
    stack = TransitionParserStack.empty(1, 4)
    stack.update(LEFT, VEC1)
    assert stack.tapes.tolist() == [[[0., 0., 0., 0.]]]

    stack = TransitionParserStack.empty(1, 4)
    stack.update(RIGHT, VEC1)
    assert stack.tapes.tolist() == [[[0., 0., 0., 0.]]]
def test_shift_shift_right(self):
    stack = TransitionParserStack.empty(1, 4)
    stack.update(SHIFT, VEC1)
    assert stack.tapes.tolist() == [[[1., 0., 0., 0.]]]

    stack.update(SHIFT, VEC2)
    assert stack.tapes.tolist() == [[[0., 1., 0., 0.],
                                     [1., 0., 0., 0.]]]

    # RIGHT reduces the top two cells, keeping VEC1 as the new top;
    # the tape still grows by one cell, zero-padded below the top.
    stack.update(RIGHT, VEC1)
    assert stack.tapes.tolist() == [[[1., 0., 0., 0.],
                                     [0., 0., 0., 0.],
                                     [0., 0., 0., 0.]]]
def test_superpos_empty(self):
    stack = TransitionParserStack.empty(1, 4, None)
    # A uniform policy superposes all three actions with weight 1/3 each.
    policy = torch.ones(1, 3) / 3
    stack.update(SHIFT, VEC1)
    stack.update(SHIFT, VEC2)
    stack.update(policy, VEC1)
    expected = [[
        [2 / 3, 1 / 3, 0, 0],
        [0, 1 / 3, 0, 0],
        [1 / 3, 0, 0, 0],
    ]]
    torch.testing.assert_allclose(stack.tapes, torch.tensor(expected))
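# The expected tape above is the uniform mixture of the three single-action
# outcomes. A runnable sketch of that arithmetic; the reduce semantics (LEFT
# keeps the top cell, RIGHT keeps the second cell) are assumptions inferred
# from the tests above:
#
#     import torch
#
#     after_shift = torch.tensor([[[1., 0., 0., 0.],   # push VEC1
#                                  [0., 1., 0., 0.],
#                                  [1., 0., 0., 0.]]])
#     after_left = torch.tensor([[[0., 1., 0., 0.],    # keep VEC2, zero-pad
#                                 [0., 0., 0., 0.],
#                                 [0., 0., 0., 0.]]])
#     after_right = torch.tensor([[[1., 0., 0., 0.],   # keep VEC1, zero-pad
#                                  [0., 0., 0., 0.],
#                                  [0., 0., 0., 0.]]])
#     mixture = (after_shift + after_left + after_right) / 3
#     # mixture matches the `expected` tape in test_superpos_empty.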
def forward(self, tokens, actions):
    batch_size, num_actions = actions.size()
    actions, act_mask = self._get_actions_and_mask(actions)
    buf_mask = self._get_buffer_mask(tokens)
    buf = self.embedder(tokens)

    stack = TransitionParserStack.empty(batch_size, self.stack_dim, device=buf.device)
    summary = torch.zeros(batch_size, self.summary_dim, device=buf.device)
    self.controller.reset(batch_size, device=buf.device)
    buf_pos = self._init_buffer_pos(buf_mask)

    policies = []
    for _ in range(num_actions):  # num_actions = 2 * num_tokens - 1.
        # Compute the policy using the controller. `buf_pos @ buf` is a soft
        # attention read of the current front-of-buffer token.
        buf_head = torch.squeeze(buf_pos @ buf, dim=1)
        state = self.controller(buf_head, summary)
        policy = torch.softmax(self.policy(state), dim=-1)
        policies.append(policy)

        # Update the stack, buffer, and summary.
        stack_policy = self._get_stack_policy(policy)
        buf_pos = self._shift_buffer_pos(buf_pos, stack_policy, buf_mask)
        stack.update(stack_policy, buf_head)
        summary = self._summarize(stack)

    # Record the actions in the stack space and then map to label space.
    policies = torch.stack(policies, dim=1)
    output_dict = {"policies": policies}
    policies = self.aligner(policies)

    if actions is not None:
        loss = sequence_cross_entropy_with_logits(policies, actions, act_mask)
        output_dict["loss"] = loss
        self.accuracy(policies, actions, act_mask)

    return output_dict
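# `_shift_buffer_pos` is not shown in this section. A minimal sketch of one
# plausible implementation, assuming `buf_pos` is a soft one-hot pointer of
# shape (batch, 1, num_tokens) and that SHIFT is action index 0 (both are
# assumptions, not the repository's definition): only the SHIFT probability
# mass advances the pointer, since SHIFT is the only action that consumes a
# buffer token.
def _shift_buffer_pos_sketch(buf_pos, stack_policy, buf_mask, shift_index=0):
    # Move the pointer distribution one token to the right, without wraparound.
    moved = torch.nn.functional.pad(buf_pos, (1, 0))[..., :-1]
    moved = moved * buf_mask.unsqueeze(1).float()  # never point at padding
    p_shift = stack_policy[:, shift_index].view(-1, 1, 1)
    return p_shift * moved + (1.0 - p_shift) * buf_pos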
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             controller: StackController,
             summary_size: int,
             hard: bool = False):
    super().__init__(vocab)
    self.embedder = text_field_embedder
    self.controller = controller
    self.accuracy = CategoricalAccuracy()

    self.stack_dim = self.embedder.get_output_dim()
    self.summary_size = summary_size
    self.summary_dim = controller.get_summary_dim()
    # The summary is a flattened read of the top `summary_size` stack cells.
    assert self.stack_dim * summary_size == self.summary_dim

    output_dim = controller.get_output_dim()
    num_actions = TransitionParserStack.get_num_actions()
    self.policy = torch.nn.Linear(output_dim, num_actions)
    self.hard = hard

    # We use this linear transformation to align the stack action space to the
    # label action space.
    self.aligner = torch.nn.Linear(num_actions, num_actions)
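# `_summarize` is not shown in this section. Given the assertion above that
# stack_dim * summary_size == summary_dim, a minimal sketch is to read the top
# `summary_size` cells of the tape and flatten them into one vector; the
# zero-padding of stacks shallower than `summary_size` is an assumption.
def _summarize_sketch(stack, summary_size, stack_dim):
    tapes = stack.tapes  # (batch, depth, stack_dim), top cell first
    batch_size, depth, _ = tapes.size()
    if depth < summary_size:
        pad = tapes.new_zeros(batch_size, summary_size - depth, stack_dim)
        tapes = torch.cat([tapes, pad], dim=1)
    return tapes[:, :summary_size].reshape(batch_size, summary_size * stack_dim)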
def test_returns(self):
    # update should return the updated tapes tensor itself.
    stack = TransitionParserStack.empty(1, 4, None)
    tapes = stack.update(SHIFT, VEC1)
    assert tapes is stack.tapes
def test_get_num_actions(self):
    assert TransitionParserStack.get_num_actions() == 3
def test_initialize(self):
    # A fresh stack has an empty tape for each batch element.
    stack = TransitionParserStack.empty(1, 4)
    assert stack.tapes.tolist() == [[]]
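# The tests above rely on module-level fixtures that are not shown. A plausible
# reconstruction, consistent with the uniform `torch.ones(1, 3) / 3` policy in
# test_superpos_empty and with the expected tapes; the values of VEC1 and VEC2
# follow from the asserted tapes, while the action ordering is an assumption:
SHIFT = torch.tensor([[1., 0., 0.]])   # one-hot policies over the 3 actions
LEFT = torch.tensor([[0., 1., 0.]])
RIGHT = torch.tensor([[0., 0., 1.]])
VEC1 = torch.tensor([[1., 0., 0., 0.]])
VEC2 = torch.tensor([[0., 1., 0., 0.]])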