Example #1
0
    def test_StackLSTM(self):
        """Push one element onto a fresh StackLSTM, then pop it back off.

        The stack is expected to start with an implicit "Root" element that is
        not counted by ``len()``. After the pop, the summary embedding must be
        exactly the empty embedding the stack was constructed with.
        """
        hidden_size = 100
        num_layers = 2

        root = Element("Root")
        node = Element("Node")

        lstm = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)
        # Two independent zero tensors: one for the hidden state, one for the cell.
        zero_state = tuple(torch.zeros(num_layers, 1, hidden_size) for _ in range(2))
        blank_embedding = torch.zeros(1, hidden_size)
        stack = StackLSTM(lstm, zero_state, blank_embedding)

        # After one push the new element sits directly above the root.
        stack.push(blank_embedding, node)
        self.assertEqual(len(stack), 1)
        for depth, expected in enumerate((node, root)):
            self.assertEqual(stack.element_from_top(depth), expected)
        self.assertEqual(stack.embedding().shape, blank_embedding.shape)

        # Popping hands back the pushed element and leaves only the root.
        popped_element = stack.pop()[1]
        self.assertEqual(popped_element, node)
        self.assertEqual(len(stack), 0)
        self.assertEqual(stack.element_from_top(0), root)
        self.assertTrue(torch.equal(stack.embedding(), blank_embedding))
Example #2
0
    def push_action(self, state: ParserState, target_action_idx: int) -> None:
        """Apply one transition-system action to ``state`` in place.

        The action's embedding is always pushed onto ``state.action_stackrnn``.
        What happens next depends on the action kind:

        * SHIFT (``self.shift_idx``): pop the top token from the buffer and push
          it onto the stack.
        * REDUCE (``self.reduce_idx``): pop stack items back to and including
          the most recent open non-terminal. Combine their representations
          with ``self.p_compositional`` and push the result as a single closed
          item.
        * Any index in ``self.valid_NT_idxs``: open a new non-terminal on the
          stack.

        Args:
            state (ParserState): The state of the stack, buffer and action
            target_action_idx (int): Index of the action to process

        Raises:
            ValueError: If ``target_action_idx`` is neither SHIFT, REDUCE nor a
                valid non-terminal. The action has already been pushed onto
                ``state.action_stackrnn`` at that point.
        """

        # Update action_stackrnn (done for every action kind).
        action_embedding = self.actions_lookup(
            cuda_utils.Variable(torch.LongTensor([target_action_idx]))
        )
        state.action_stackrnn.push(action_embedding, Element(target_action_idx))

        # Update stack_stackrnn
        if target_action_idx == self.shift_idx:
            # To SHIFT,
            # 1. Pop T from buffer
            # 2. Push T into stack
            state.is_open_NT.append(False)
            token_embedding, token = state.buffer_stackrnn.pop()
            state.stack_stackrnn.push(token_embedding, Element(token))

        elif target_action_idx == self.reduce_idx:
            # To REDUCE
            # 1. Pop Ts from stack until hit NT
            # 2. Pop the open NT from stack and close it
            # 3. Compute compositionalRep and push into stack
            state.num_open_NT -= 1
            # Collected top-of-stack first, so the open NT ends up last.
            popped_rep = []
            nt_tree = []

            # is_open_NT is kept in step with the stack: one flag per item.
            while not state.is_open_NT[-1]:
                assert len(state.stack_stackrnn) > 0, "How come stack is empty!"
                state.is_open_NT.pop()
                top_of_stack = state.stack_stackrnn.pop()
                popped_rep.append(top_of_stack[0])
                nt_tree.append(top_of_stack[1])

            # pop the open NT and close it
            top_of_stack = state.stack_stackrnn.pop()
            popped_rep.append(top_of_stack[0])
            nt_tree.append(top_of_stack[1])

            # The reduced constituent replaces the open NT as a closed item.
            state.is_open_NT.pop()
            state.is_open_NT.append(False)

            compostional_rep = self.p_compositional(popped_rep)
            combinedElement = Element(nt_tree)

            state.stack_stackrnn.push(compostional_rep, combinedElement)

        elif target_action_idx in self.valid_NT_idxs:

            # if this is root prediction and if that root is one
            # of the unsupported intents
            if (
                len(state.predicted_actions_idx) == 1
                and target_action_idx in self.ignore_subNTs_roots
            ):
                state.found_unsupported = True

            state.is_open_NT.append(True)
            state.num_open_NT += 1
            state.stack_stackrnn.push(action_embedding, Element(target_action_idx))
        else:
            # `assert "<non-empty string>"` can never fail, so raise explicitly
            # instead of silently accepting an unknown action.
            raise ValueError(
                "not a valid action: {}".format(
                    self.actions_vocab.itos[target_action_idx]
                )
            )
Example #3
0
    def forward(
        self,
        tokens: torch.Tensor,
        seq_lens: torch.Tensor,
        dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
        actions: Optional[List[List[int]]] = None,
        contextual_token_embeddings: Optional[torch.Tensor] = None,
        beam_size=1,
        top_k=1,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """RNNG forward function.

        Args:
            tokens (torch.Tensor): list of tokens
            seq_lens (torch.Tensor): list of sequence lengths (not read by this
                method; kept for interface compatibility)
            dict_feat (Optional[Tuple[torch.Tensor, ...]]): dictionary or gazetteer
                features for each token
            actions (Optional[List[List[int]]]): Used only during training.
                Oracle actions for the instances; only ``actions[0]`` is used
                (batch size is 1).
            contextual_token_embeddings (Optional[torch.Tensor]): optional extra
                input passed through to the embedding module
            beam_size (int): number of partial parses kept per step. Forced to 1
                unless ``self.stage`` is ``Stage.TEST``; values < 1 become 1.
            top_k (int): number of finished parses to return. Forced to 1
                unless ``self.stage`` is ``Stage.TEST``.

        Returns:
            list of top k tuple of predicted actions tensor and corresponding scores tensor.
            Tensor shape:
            (batch_size, action_length)
            (batch_size, action_length, number_of_actions)
        """

        if self.stage != Stage.TEST:
            beam_size = 1
            top_k = 1

        if self.training:
            assert actions is not None, "actions must be provided for training"
            # Reversed so the next oracle action can be popped off the end in O(1).
            actions_idx_rev = list(reversed(actions[0]))
        else:
            # Reseeds the global RNG on every inference call (presumably for
            # reproducible inference).
            torch.manual_seed(0)

        beam_size = max(beam_size, 1)

        # Reverse the order of input tokens.
        tokens_list_rev = torch.flip(tokens, [len(tokens.size()) - 1])

        # Aggregate inputs for embedding module.
        embedding_input = [tokens]
        if dict_feat is not None:
            embedding_input.append(dict_feat)
        if contextual_token_embeddings is not None:
            embedding_input.append(contextual_token_embeddings)

        # Embed in the original order, then reverse along the token axis so the
        # embeddings line up with tokens_list_rev.
        token_embeddings = self.embedding(*embedding_input)
        token_embeddings = torch.flip(token_embeddings, [len(tokens.size()) - 1])

        # Batch size is always = 1. So we squeeze the batch_size dimension.
        token_embeddings = token_embeddings.squeeze(0)
        tokens_list_rev = tokens_list_rev.squeeze(0)

        # Tokens are pushed last-to-first, so the first token is on top of the buffer.
        initial_state = ParserState(self)
        for i in range(token_embeddings.size()[0]):
            token_embedding = token_embeddings[i].unsqueeze(0)
            tok = tokens_list_rev[i]
            initial_state.buffer_stackrnn.push(token_embedding, Element(tok))

        beam = [initial_state]
        while beam and any(not state.finished() for state in beam):
            # Stores plans for expansion as (score, state, action); the score is
            # the accumulated negative log-probability.
            plans: List[Tuple[float, ParserState, int]] = []
            # Expand current beam states
            for state in beam:
                # Keep terminal states (marked with the dummy action -1).
                if state.finished():
                    plans.append((state.neg_prob, state, -1))
                    continue

                #  translating Expression p_t = affine_transform({pbias, S,
                #  stack_summary, B, buffer_summary, A, action_summary});
                stack = state.stack_stackrnn
                stack_summary = stack.embedding()
                action_summary = state.action_stackrnn.embedding()
                buffer_summary = state.buffer_stackrnn.embedding()
                if self.dropout_layer.p > 0:
                    stack_summary = self.dropout_layer(stack_summary)
                    action_summary = self.dropout_layer(action_summary)
                    buffer_summary = self.dropout_layer(buffer_summary)

                # feature for index of last open non-terminal
                last_open_NT_feature = torch.zeros(len(self.actions_vocab))
                open_NT_exists = state.num_open_NT > 0

                if (
                    len(stack) > 0
                    and open_NT_exists
                    and self.ablation_use_last_open_NT_feature
                ):
                    last_open_NT = None
                    try:
                        open_NT = state.is_open_NT[::-1].index(True)
                        last_open_NT = stack.element_from_top(open_NT)
                    except ValueError:
                        pass
                    if last_open_NT is not None:
                        last_open_NT_feature[last_open_NT.node] = 1.0
                last_open_NT_feature = last_open_NT_feature.unsqueeze(0)

                summaries = []
                if self.ablation_use_buffer:
                    summaries.append(buffer_summary)
                if self.ablation_use_stack:
                    summaries.append(stack_summary)
                if self.ablation_use_action:
                    summaries.append(action_summary)
                if self.ablation_use_last_open_NT_feature:
                    summaries.append(last_open_NT_feature)

                state.action_p = self.action_linear(torch.cat(summaries, dim=1))

                log_probs = F.log_softmax(state.action_p, dim=1)[0]

                for action in self.valid_actions(state):
                    plans.append(
                        (state.neg_prob - log_probs[action].item(), state, action)
                    )

            beam = []
            # Take actions to regenerate the beam
            for neg_prob, state, predicted_action_idx in sorted(plans)[:beam_size]:
                # Skip terminal states
                if state.finished():
                    beam.append(state)
                    continue

                # Several surviving plans may share a parent, so branch out with
                # a copy; with a single survivor it can be mutated in place.
                if beam_size > 1:
                    state = state.copy()

                state.predicted_actions_idx.append(predicted_action_idx)

                # During training the oracle action drives the transition, while
                # the predicted action is still recorded above.
                target_action_idx = predicted_action_idx
                if self.training:
                    assert actions_idx_rev, "Actions and tokens may not be in sync."
                    target_action_idx = actions_idx_rev.pop()

                if not (
                    self.constraints_ignore_loss_for_unsupported
                    and state.found_unsupported
                ):
                    state.action_scores.append(state.action_p)

                self.push_action(state, target_action_idx)

                state.neg_prob = neg_prob
                beam.append(state)
            # End for
        # End while
        assert len(beam) > 0, "How come beam is empty?"
        # Validate every surviving parse, not just whichever state the loops
        # above happened to touch last.
        for final_state in beam:
            stack_len = len(final_state.stack_stackrnn)
            buffer_len = len(final_state.buffer_stackrnn)
            assert stack_len == 1, "How come stack len is " + str(stack_len)
            assert buffer_len == 0, "How come buffer len is " + str(buffer_len)

        # Unsqueeze to add batch dimension before returning.
        return [
            (
                cuda_utils.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                torch.cat(state.action_scores).unsqueeze(0),
            )
            for state in sorted(beam)[:top_k]
        ]
Example #4
0
 def populate_buffer(self):
     """Build a fresh ParserState whose buffer holds two placeholder tokens.

     Each placeholder has a (1, 30) zero embedding and the element "Token".
     """
     parser_state = ParserState(self.parser)
     token_buffer = parser_state.buffer_stackrnn
     for _ in range(2):
         token_buffer.push(torch.zeros(1, 30), Element("Token"))
     return parser_state
Example #5
0
    def forward(
        self,
        tokens: torch.Tensor,
        seq_lens: torch.Tensor,
        dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
        actions: Optional[List[List[int]]] = None,
        beam_size: int = 1,
        topk: int = 1,
    ):
        """RNNG forward function.

        Args:
            tokens (torch.Tensor): list of tokens
            seq_lens (torch.Tensor): list of sequence lengths (not read by this
                method; kept for interface compatibility)
            dict_feat (Optional[Tuple[torch.Tensor, ...]]): dictionary or gazetteer
                features for each token, unpacked as (ids, weights, lengths)
            actions (Optional[List[List[int]]]): Used only during training.
                Oracle actions for the instances; only ``actions[0]`` is used.
            beam_size (int): Beam size; must be 1 during training. Values < 1
                are treated as 1.
            topk (int) : Number of top results from the method.
                If beam_size is 1 this is 1.

        Returns:
            if topk <= 1
                tuple of (predicted action indices, action scores) for the
                ``min(beam)`` parse, each with a leading batch dimension of 1
            else
                list of up to ``topk`` such tuples, ordered by ``sorted(beam)``
                (presumably lowest accumulated negative log-probability first)
        """
        if self.training:
            assert beam_size == 1, "beam_size must be 1 during training"
            assert actions is not None, "actions must be provided for training"
            # Reversed so the next oracle action can be popped off the end in O(1).
            actions_idx_rev = list(reversed(actions[0]))
        else:
            torch.manual_seed(0)

        beam_size = max(beam_size, 1)

        # Reverse the order of indices along last axis before embedding lookup.
        tokens_list_rev = torch.flip(tokens, [len(tokens.size()) - 1])
        dict_feat_rev = None
        if dict_feat:
            dict_ids, dict_weights, dict_lengths = dict_feat
            dict_ids_rev = torch.flip(dict_ids, [len(dict_ids.size()) - 1])
            dict_weights_rev = torch.flip(dict_weights, [len(dict_weights.size()) - 1])
            dict_lengths_rev = torch.flip(dict_lengths, [len(dict_lengths.size()) - 1])
            dict_feat_rev = (dict_ids_rev, dict_weights_rev, dict_lengths_rev)

        embedding_input = (
            [tokens_list_rev, dict_feat_rev]
            if dict_feat_rev is not None
            else [tokens_list_rev]
        )
        tok_embeddings = self.embedding(*embedding_input)

        # Batch size is always = 1. So we squeeze the batch_size dimension.
        tok_embeddings = tok_embeddings.squeeze(0)
        tokens_list_rev = tokens_list_rev.squeeze(0)

        initial_state = ParserState(self)
        for i in range(tok_embeddings.size()[0]):
            tok_embedding = tok_embeddings[i].unsqueeze(0)
            tok = tokens_list_rev[i]
            initial_state.buffer_stackrnn.push(tok_embedding, Element(tok))

        beam = [initial_state]
        while beam and any(not state.finished() for state in beam):
            # Stores plans for expansion as (score, state, action)
            plans: List[Tuple[float, ParserState, int]] = []
            # Action logits of each expanded state, keyed by id(state). The
            # parent states stay referenced from `plans` for this whole
            # iteration, so their ids cannot be reused meanwhile.
            action_p_by_state = {}
            # Expand current beam states
            for state in beam:
                # Keep terminal states
                if state.finished():
                    plans.append((state.neg_prob, state, -1))
                    continue

                #  translating Expression p_t = affine_transform({pbias, S,
                #  stack_summary, B, buffer_summary, A, action_summary});
                stack = state.stack_stackrnn
                stack_summary = stack.embedding()
                action_summary = state.action_stackrnn.embedding()
                buffer_summary = state.buffer_stackrnn.embedding()
                if self.dropout_layer.p > 0:
                    stack_summary = self.dropout_layer(stack_summary)
                    action_summary = self.dropout_layer(action_summary)
                    buffer_summary = self.dropout_layer(buffer_summary)

                # feature for index of last open non-terminal
                last_open_NT_feature = torch.zeros(len(self.actions_vocab))
                open_NT_exists = state.num_open_NT > 0

                if (
                    len(stack) > 0
                    and open_NT_exists
                    and self.ablation_use_last_open_NT_feature
                ):
                    last_open_NT = None
                    try:
                        open_NT = state.is_open_NT[::-1].index(True)
                        # NOTE(review): other code in this file calls
                        # `element_from_top`; confirm this StackLSTM version
                        # really exposes `ele_from_top`.
                        last_open_NT = stack.ele_from_top(open_NT)
                    except ValueError:
                        pass
                    if last_open_NT is not None:
                        last_open_NT_feature[last_open_NT.node] = 1.0
                last_open_NT_feature = last_open_NT_feature.unsqueeze(0)

                summaries = []
                if self.ablation_use_buffer:
                    summaries.append(buffer_summary)
                if self.ablation_use_stack:
                    summaries.append(stack_summary)
                if self.ablation_use_action:
                    summaries.append(action_summary)
                if self.ablation_use_last_open_NT_feature:
                    summaries.append(last_open_NT_feature)

                action_p = self.action_linear(torch.cat(summaries, dim=1))
                action_p_by_state[id(state)] = action_p

                log_probs = F.log_softmax(action_p, dim=1)[0]

                # Plain floats for ranking: no autograd graph is kept alive by
                # the beam scores.
                for action in self.valid_actions(state):
                    plans.append(
                        (state.neg_prob - log_probs[action].item(), state, action)
                    )

            beam = []
            # Take actions to regenerate the beam
            for neg_prob, state, predicted_action_idx in sorted(plans)[:beam_size]:
                # Skip terminal states
                if state.finished():
                    beam.append(state)
                    continue

                # Use the logits of *this* plan's parent. The loop-local
                # `action_p` from the expansion loop belongs to the last expanded
                # state, which is wrong whenever beam_size > 1.
                parent_action_p = action_p_by_state[id(state)]

                # Only branch out states when needed
                if beam_size > 1:
                    state = state.copy()

                state.predicted_actions_idx.append(predicted_action_idx)

                # During training the oracle action drives the transition.
                target_action_idx = predicted_action_idx
                if self.training:
                    assert actions_idx_rev, "Actions and tokens may not be in sync."
                    target_action_idx = actions_idx_rev.pop()

                if not (
                    self.constraints_ignore_loss_for_unsupported
                    and state.found_unsupported
                ):
                    state.action_scores.append(parent_action_p)

                action_embedding = self.actions_lookup(
                    cuda_utils.Variable(torch.LongTensor([target_action_idx]))
                )
                state.action_stackrnn.push(action_embedding, Element(target_action_idx))

                if target_action_idx == self.shift_idx:
                    # SHIFT: move the top buffer token onto the stack.
                    state.is_open_NT.append(False)
                    tok_embedding, token = state.buffer_stackrnn.pop()
                    state.stack_stackrnn.push(tok_embedding, Element(token))
                elif target_action_idx == self.reduce_idx:
                    # REDUCE: pop back to the last open NT, compose, push result.
                    state.num_open_NT -= 1
                    popped_rep = []
                    nt_tree = []

                    while not state.is_open_NT[-1]:
                        assert len(state.stack_stackrnn) > 0, "How come stack is empty!"
                        state.is_open_NT.pop()
                        top_of_stack = state.stack_stackrnn.pop()
                        popped_rep.append(top_of_stack[0])
                        nt_tree.append(top_of_stack[1])

                    # pop the open NT and close it
                    top_of_stack = state.stack_stackrnn.pop()
                    popped_rep.append(top_of_stack[0])
                    nt_tree.append(top_of_stack[1])

                    state.is_open_NT.pop()
                    state.is_open_NT.append(False)

                    compostional_rep = self.p_compositional(popped_rep)
                    combinedElement = Element(nt_tree)

                    state.stack_stackrnn.push(compostional_rep, combinedElement)
                elif target_action_idx in self.valid_NT_idxs:

                    # if this is root prediction and if that root is one
                    # of the unsupported intents
                    if (
                        len(state.predicted_actions_idx) == 1
                        and target_action_idx in self.ignore_subNTs_roots
                    ):
                        state.found_unsupported = True

                    state.is_open_NT.append(True)
                    state.num_open_NT += 1
                    state.stack_stackrnn.push(
                        action_embedding, Element(target_action_idx)
                    )
                else:
                    # `assert "<non-empty string>"` can never fail; raise instead.
                    raise ValueError(
                        "not a valid action: {}".format(
                            self.actions_vocab.itos[target_action_idx]
                        )
                    )

                state.neg_prob = neg_prob
                beam.append(state)
            # End for
        # End while
        assert len(beam) > 0, "How come beam is empty?"
        # Validate every surviving parse, not only the last one touched above.
        for final_state in beam:
            stack_len = len(final_state.stack_stackrnn)
            buffer_len = len(final_state.buffer_stackrnn)
            assert stack_len == 1, "How come stack len is " + str(stack_len)
            assert buffer_len == 0, "How come buffer len is " + str(buffer_len)

        # Add batch dimension before returning.
        if topk <= 1:
            state = min(beam)
            return (
                torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                torch.cat(state.action_scores).unsqueeze(0),
            )
        else:
            return [
                (
                    torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                    torch.cat(state.action_scores).unsqueeze(0),
                )
                for state in sorted(beam)[:topk]
            ]