Example #1
    def get_loss(self, logits: torch.Tensor, target_actions: torch.Tensor,
                 context: torch.Tensor):
        """
        Shapes:
            logits[1]: action scores: (1, sequence_length, number_of_actions)
            target_actions: (1, sequence_length)
        """

        # Supports beam search output: if top-k predictions are present,
        # logits carries an extra dimension and the indexing below succeeds;
        # otherwise it raises and we fall through to the single-prediction case.
        try:
            logits[0][0][0][0]  # probe for the extra top-k dimension
            action_scores = logits[0][1].squeeze(0)
            target_actions = target_actions[0].squeeze(0)

        except (TypeError, IndexError):
            # Get rid of the batch dimension
            action_scores = logits[1].squeeze(0)
            target_actions = target_actions.squeeze(0)

        action_scores_list = torch.chunk(action_scores,
                                         action_scores.size()[0])
        target_vars = [
            cuda_utils.Variable(torch.LongTensor([t])) for t in target_actions
        ]
        losses = [
            self.loss_func(action, target).view(1)
            for action, target in zip(action_scores_list, target_vars)
        ]
        total_loss = torch.sum(torch.cat(losses)) if len(losses) > 0 else None
        return total_loss
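
A minimal standalone sketch of the per-step loss pattern above; the shapes and the nn.CrossEntropyLoss choice are assumptions, since the snippet's own loss_func is not shown:

    import torch
    import torch.nn as nn

    loss_func = nn.CrossEntropyLoss()
    action_scores = torch.randn(5, 8)           # 5 steps, 8 possible actions
    target_actions = torch.randint(0, 8, (5,))  # one gold action per step

    # Chunk into one (1, num_actions) row per step and sum per-step losses.
    action_scores_list = torch.chunk(action_scores, action_scores.size(0))
    losses = [
        loss_func(scores, target.view(1)).view(1)
        for scores, target in zip(action_scores_list, target_actions)
    ]
    total_loss = torch.sum(torch.cat(losses)) if losses else None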
Example #2
    def _word_forward(self, inputs: torch.Tensor,
                      word_idx: int) -> torch.Tensor:
        # inputs -> (batch, words, embed_dim)
        start_idx = word_idx - self.fwd_bwd_ctxt_len
        word_with_bwd_context = inputs.narrow(1, start_idx,
                                              self.fwd_bwd_ctxt_len + 1)

        word_with_fwd_context = inputs.narrow(1, word_idx,
                                              self.fwd_bwd_ctxt_len + 1)

        start_idx = word_idx - self.surr_ctxt_len
        word_with_surr_context = inputs.narrow(1, start_idx,
                                               2 * self.surr_ctxt_len + 1)

        padding = cuda_utils.Variable(
            torch.cat([self.padding_tensor] * inputs.size()[0]))
        conv_in_bwd_context = torch.cat((word_with_bwd_context, padding),
                                        dim=1)
        conv_in_fwd_context = torch.cat((padding, word_with_fwd_context),
                                        dim=1)

        bwd_ctxt_rep = self._conv_maxpool(conv_in_bwd_context, self.convs_bwd)
        fwd_ctxt_rep = self._conv_maxpool(conv_in_fwd_context, self.convs_fwd)
        surr_ctxt_rep = self._conv_maxpool(word_with_surr_context,
                                           self.convs_surr)

        # Full representation combining all contextual representations.
        return torch.cat(
            (
                self.bwd_fc(bwd_ctxt_rep),
                self.fwd_fc(fwd_ctxt_rep),
                self.surr_fc(surr_ctxt_rep),
            ),
            dim=1,
        )
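
The three narrow() calls are the heart of this method; a minimal sketch with assumed sizes, where ctxt_len stands in for both fwd_bwd_ctxt_len and surr_ctxt_len:

    import torch

    ctxt_len = 2
    inputs = torch.randn(1, 10, 4)  # (batch, words, embed_dim)
    word_idx = 5

    # ctxt_len tokens of left context plus the word itself
    with_bwd = inputs.narrow(1, word_idx - ctxt_len, ctxt_len + 1)
    # the word itself plus ctxt_len tokens of right context
    with_fwd = inputs.narrow(1, word_idx, ctxt_len + 1)
    # a symmetric window centered on the word
    with_surr = inputs.narrow(1, word_idx - ctxt_len, 2 * ctxt_len + 1)

    print(with_bwd.shape, with_fwd.shape, with_surr.shape)
    # torch.Size([1, 3, 4]) torch.Size([1, 3, 4]) torch.Size([1, 5, 4])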
Example #3
    def get_loss(
        self, logits: torch.Tensor, target_actions: torch.Tensor, context: torch.Tensor
    ):
        # action_scores is a 2D tensor of shape (sequence_length, number_of_actions)
        # target_actions is a 1D tensor of integers of length sequence_length

        # Get rid of the batch dimension
        action_scores = logits[1].squeeze(0)
        target_actions = target_actions.squeeze(0)

        action_scores_list = torch.chunk(action_scores, action_scores.size()[0])
        target_vars = [
            cuda_utils.Variable(torch.LongTensor([t])) for t in target_actions
        ]
        losses = [
            self.loss_func(action, target).view(1)
            for action, target in zip(action_scores_list, target_vars)
        ]
        total_loss = torch.sum(torch.cat(losses)) if len(losses) > 0 else None
        return total_loss
Example #4
    def __init__(self, config: Config, embed_dim: int) -> None:
        super().__init__(config)

        self.fwd_bwd_ctxt_len = config.fwd_bwd_context_len
        self.surr_ctxt_len = config.surrounding_context_len
        self.ctxt_pad_len = max(self.fwd_bwd_ctxt_len, self.surr_ctxt_len)
        self.padding_tensor = cuda_utils.Variable(
            torch.zeros(1, self.fwd_bwd_ctxt_len, embed_dim),
            requires_grad=False)

        bwd_convs, fwd_convs, surr_convs = [], [], []
        in_channels = 1
        out_channels = config.cnn.kernel_num
        kernel_sizes = config.cnn.kernel_sizes
        for k in kernel_sizes:
            bwd_convs.append(
                nn.Conv2d(in_channels, out_channels, (k, embed_dim)))
            fwd_convs.append(
                nn.Conv2d(in_channels, out_channels, (k, embed_dim)))
            surr_convs.append(
                nn.Conv2d(in_channels, out_channels, (k, embed_dim)))
        self.convs_bwd = nn.ModuleList(bwd_convs)
        self.convs_fwd = nn.ModuleList(fwd_convs)
        self.convs_surr = nn.ModuleList(surr_convs)

        # Token representation size with each context.
        token_rep_len = len(kernel_sizes) * out_channels
        self.bwd_fc = nn.Linear(token_rep_len, token_rep_len)
        self.fwd_fc = nn.Linear(token_rep_len, token_rep_len)
        self.surr_fc = nn.Linear(token_rep_len, token_rep_len)

        self.ctxt_pad = nn.ConstantPad1d(
            (self.ctxt_pad_len, self.ctxt_pad_len), 0)

        self.representation_dim = 3 * len(kernel_sizes) * out_channels
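
The _conv_maxpool step these layers feed is not shown in the snippet, so here is a hypothetical sketch of the usual conv-then-max-pool pattern for (k, embed_dim) kernels like the ones built above:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    kernel_sizes, kernel_num, embed_dim = [3, 4], 16, 8
    convs = nn.ModuleList(
        [nn.Conv2d(1, kernel_num, (k, embed_dim)) for k in kernel_sizes]
    )

    tokens = torch.randn(1, 9, embed_dim)  # (batch, words, embed_dim)
    x = tokens.unsqueeze(1)                # add the in_channels dimension

    pooled = []
    for conv in convs:
        c = F.relu(conv(x)).squeeze(3)     # (batch, kernel_num, out_len)
        pooled.append(F.max_pool1d(c, c.size(2)).squeeze(2))
    rep = torch.cat(pooled, dim=1)         # (1, len(kernel_sizes) * kernel_num)
    print(rep.shape)                       # torch.Size([1, 32])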
Example #5
    def push_action(self, state: ParserState, target_action_idx: int) -> None:
        """Used for updating the state with a target next action

        Args:
            state (ParserState): The state of the stack, buffer and action
            target_action_idx (int): Index of the action to process
        """

        # Update action_stackrnn
        action_embedding = self.actions_lookup(
            cuda_utils.Variable(torch.LongTensor([target_action_idx])))
        state.action_stackrnn.push(action_embedding,
                                   Element(target_action_idx))

        # Update stack_stackrnn
        if target_action_idx == self.shift_idx:
            # To SHIFT,
            # 1. Pop T from buffer
            # 2. Push T into stack
            state.is_open_NT.append(False)
            token_embedding, token = state.buffer_stackrnn.pop()
            state.stack_stackrnn.push(token_embedding, Element(token))

        elif target_action_idx == self.reduce_idx:
            # To REDUCE
            # 1. Pop Ts from stack until hit NT
            # 2. Pop the open NT from stack and close it
            # 3. Compute compositionalRep and push into stack
            state.num_open_NT -= 1
            popped_rep = []
            nt_tree = []

            while not state.is_open_NT[-1]:
                assert len(
                    state.stack_stackrnn) > 0, "Stack must not be empty"
                state.is_open_NT.pop()
                top_of_stack = state.stack_stackrnn.pop()
                popped_rep.append(top_of_stack[0])
                nt_tree.append(top_of_stack[1])

            # pop the open NT and close it
            top_of_stack = state.stack_stackrnn.pop()
            popped_rep.append(top_of_stack[0])
            nt_tree.append(top_of_stack[1])

            state.is_open_NT.pop()
            state.is_open_NT.append(False)

            compositional_rep = self.p_compositional(popped_rep)
            combined_element = Element(nt_tree)

            state.stack_stackrnn.push(compositional_rep, combined_element)

        elif target_action_idx in self.valid_NT_idxs:
            # If this is the root prediction and that root is one
            # of the unsupported intents
            if (len(state.predicted_actions_idx) == 1
                    and target_action_idx in self.ignore_subNTs_roots):
                state.found_unsupported = True

            state.is_open_NT.append(True)
            state.num_open_NT += 1
            state.stack_stackrnn.push(action_embedding,
                                      Element(target_action_idx))
        else:
            raise ValueError("not a valid action: {}".format(
                self.actions_vocab.itos[target_action_idx]))
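
A toy sketch of the same SHIFT/REDUCE bookkeeping with plain lists standing in for the stack RNNs and embeddings; the non-terminal and tokens are illustrative only:

    buffer = list(reversed(["flights", "to", "boston"]))
    stack, is_open_NT = [], []

    stack.append("IN:GET_FLIGHT")   # open a non-terminal
    is_open_NT.append(True)

    while buffer:                   # SHIFT: move tokens from buffer to stack
        stack.append(buffer.pop())
        is_open_NT.append(False)

    children = []                   # REDUCE: pop until the open NT
    while not is_open_NT[-1]:
        is_open_NT.pop()
        children.append(stack.pop())
    children.append(stack.pop())    # pop the open NT itself and close it
    is_open_NT.pop()
    stack.append(list(reversed(children)))
    is_open_NT.append(False)

    print(stack)  # [['IN:GET_FLIGHT', 'flights', 'to', 'boston']]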
Example #6
    def forward(
        self,
        tokens: torch.Tensor,
        seq_lens: torch.Tensor,
        dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
        actions: Optional[List[List[int]]] = None,
        beam_size: int = 1,
        topk: int = 1,
    ):
        """RNNG forward function.

        Args:
            tokens (torch.Tensor): tensor of token ids
            seq_lens (torch.Tensor): tensor of sequence lengths
            dict_feat (Optional[Tuple[torch.Tensor, ...]]): dictionary or gazetteer
                features for each token
            actions (Optional[List[List[int]]]): Used only during training.
                Oracle actions for the instances.
            beam_size (int): Beam size; used only during inference
            topk (int): Number of top-scoring results to return;
                must be 1 when beam_size is 1.

        Returns:
            If topk == 1, a tuple of the list of predicted actions and the
            list of corresponding scores; otherwise, a list of such
            (actions, scores) tuples, one per top-scoring hypothesis.
        """
        if self.training:
            assert beam_size == 1, "beam_size must be 1 during training"
            assert actions is not None, "actions must be provided for training"
            actions_idx_rev = list(reversed(actions[0]))
        else:
            torch.manual_seed(0)

        beam_size = max(beam_size, 1)

        # Reverse the order of indices along last axis before embedding lookup.
        tokens_list_rev = torch.flip(tokens, [len(tokens.size()) - 1])
        dict_feat_rev = None
        if dict_feat:
            dict_ids, dict_weights, dict_lengths = dict_feat
            dict_ids_rev = torch.flip(dict_ids, [len(dict_ids.size()) - 1])
            dict_weights_rev = torch.flip(dict_weights, [len(dict_weights.size()) - 1])
            dict_lengths_rev = torch.flip(dict_lengths, [len(dict_lengths.size()) - 1])
            dict_feat_rev = (dict_ids_rev, dict_weights_rev, dict_lengths_rev)

        embedding_input = (
            [tokens_list_rev, dict_feat_rev]
            if dict_feat_rev is not None
            else [tokens_list_rev]
        )
        tok_embeddings = self.embedding(*embedding_input)

        # Batch size is always 1, so squeeze out the batch dimension.
        tok_embeddings = tok_embeddings.squeeze(0)
        tokens_list_rev = tokens_list_rev.squeeze(0)

        initial_state = ParserState(self)
        for i in range(tok_embeddings.size()[0]):
            tok_embedding = tok_embeddings[i].unsqueeze(0)
            tok = tokens_list_rev[i]
            initial_state.buffer_stackrnn.push(tok_embedding, Element(tok))

        beam = [initial_state]
        while beam and any(not state.finished() for state in beam):
            # Stores plans for expansion as (score, state, action)
            plans: List[Tuple[float, ParserState, int]] = []
            # Expand current beam states
            for state in beam:
                # Keep terminal states
                if state.finished():
                    plans.append((state.neg_prob, state, -1))
                    continue

                #  translating Expression p_t = affine_transform({pbias, S,
                #  stack_summary, B, buffer_summary, A, action_summary});
                stack = state.stack_stackrnn
                stack_summary = stack.embedding()
                action_summary = state.action_stackrnn.embedding()
                buffer_summary = state.buffer_stackrnn.embedding()
                if self.dropout_layer.p > 0:
                    stack_summary = self.dropout_layer(stack_summary)
                    action_summary = self.dropout_layer(action_summary)
                    buffer_summary = self.dropout_layer(buffer_summary)

                # feature for index of last open non-terminal
                last_open_NT_feature = torch.zeros(len(self.actions_vocab))
                open_NT_exists = state.num_open_NT > 0

                if (
                    len(stack) > 0
                    and open_NT_exists
                    and self.ablation_use_last_open_NT_feature
                ):
                    last_open_NT = None
                    try:
                        open_NT = state.is_open_NT[::-1].index(True)
                        last_open_NT = stack.ele_from_top(open_NT)
                    except ValueError:
                        pass
                    if last_open_NT:
                        last_open_NT_feature[last_open_NT.node] = 1.0
                last_open_NT_feature = last_open_NT_feature.unsqueeze(0)

                summaries = []
                if self.ablation_use_buffer:
                    summaries.append(buffer_summary)
                if self.ablation_use_stack:
                    summaries.append(stack_summary)
                if self.ablation_use_action:
                    summaries.append(action_summary)
                if self.ablation_use_last_open_NT_feature:
                    summaries.append(last_open_NT_feature)

                action_p = self.action_linear(torch.cat(summaries, dim=1))

                log_probs = F.log_softmax(action_p, dim=1)[0]

                for action in self.valid_actions(state):
                    plans.append((state.neg_prob - log_probs[action], state, action))

            beam = []
            # Take actions to regenerate the beam
            for neg_prob, state, predicted_action_idx in sorted(plans)[:beam_size]:
                # Skip terminal states
                if state.finished():
                    beam.append(state)
                    continue

                # Only branch out states when needed
                if beam_size > 1:
                    state = state.copy()

                state.predicted_actions_idx.append(predicted_action_idx)

                target_action_idx = predicted_action_idx
                if self.training:
                    assert (
                        len(actions_idx_rev) > 0
                    ), "Actions and tokens may not be in sync."
                    target_action_idx = actions_idx_rev[-1]
                    actions_idx_rev = actions_idx_rev[:-1]

                if not (
                    self.constraints_ignore_loss_for_unsupported
                    and state.found_unsupported
                ):
                    state.action_scores.append(action_p)

                action_embedding = self.actions_lookup(
                    cuda_utils.Variable(torch.LongTensor([target_action_idx]))
                )
                state.action_stackrnn.push(action_embedding, Element(target_action_idx))

                if target_action_idx == self.shift_idx:
                    state.is_open_NT.append(False)
                    tok_embedding, token = state.buffer_stackrnn.pop()
                    state.stack_stackrnn.push(tok_embedding, Element(token))
                elif target_action_idx == self.reduce_idx:
                    state.num_open_NT -= 1
                    popped_rep = []
                    nt_tree = []

                    while not state.is_open_NT[-1]:
                        assert len(state.stack_stackrnn) > 0, "Stack must not be empty"
                        state.is_open_NT.pop()
                        top_of_stack = state.stack_stackrnn.pop()
                        popped_rep.append(top_of_stack[0])
                        nt_tree.append(top_of_stack[1])

                    # pop the open NT and close it
                    top_of_stack = state.stack_stackrnn.pop()
                    popped_rep.append(top_of_stack[0])
                    nt_tree.append(top_of_stack[1])

                    state.is_open_NT.pop()
                    state.is_open_NT.append(False)

                    compositional_rep = self.p_compositional(popped_rep)
                    combined_element = Element(nt_tree)

                    state.stack_stackrnn.push(compositional_rep, combined_element)
                elif target_action_idx in self.valid_NT_idxs:
                    # If this is the root prediction and that root is one
                    # of the unsupported intents
                    if (
                        len(state.predicted_actions_idx) == 1
                        and target_action_idx in self.ignore_subNTs_roots
                    ):
                        state.found_unsupported = True

                    state.is_open_NT.append(True)
                    state.num_open_NT += 1
                    state.stack_stackrnn.push(
                        action_embedding, Element(target_action_idx)
                    )
                else:
                    raise ValueError(
                        "not a valid action: {}".format(
                            self.actions_vocab.itos[target_action_idx]
                        )
                    )

                state.neg_prob = neg_prob
                beam.append(state)
            # End for
        # End while
        assert len(beam) > 0, "Beam must not be empty"
        assert len(state.stack_stackrnn) == 1, "Expected stack length 1, got " + str(
            len(state.stack_stackrnn)
        )
        assert len(state.buffer_stackrnn) == 0, "Expected empty buffer, got " + str(
            len(state.buffer_stackrnn)
        )

        # Add batch dimension before returning.
        if topk <= 1:
            state = min(beam)
            return (
                torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                torch.cat(state.action_scores).unsqueeze(0),
            )
        else:
            return [
                (
                    torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                    torch.cat(state.action_scores).unsqueeze(0),
                )
                for state in sorted(beam)[:topk]
            ]
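
A toy sketch of the beam bookkeeping in the main loop above: every (state, action) plan is scored by accumulated negative log-probability, and only the best beam_size plans survive; the strings stand in for ParserState copies:

    import torch
    import torch.nn.functional as F

    beam_size = 2
    beam = [(0.0, "s0")]               # (neg_log_prob, state)

    action_scores = torch.randn(1, 4)  # one score per candidate action
    log_probs = F.log_softmax(action_scores, dim=1)[0]

    plans = [
        (neg_prob - log_probs[a].item(), state, a)
        for neg_prob, state in beam
        for a in range(log_probs.size(0))
    ]
    beam = sorted(plans)[:beam_size]   # lowest total neg-log-prob first
    print(beam)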