Example #1
    def forward(
            self,
            input: Tensor,
            state: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        inputs = input.unbind(0)
        outputs = jit.annotate(List[Tensor], [])
        mask_r = jit.annotate(Tensor, torch.zeros(1))
        mask_i = jit.annotate(Tensor, torch.zeros(1))

        for t in range(len(inputs)):
            if self.training and self.inp_drop_p > 0.0:
                if t == 0:
                    mask_i = torch.rand(inputs[t].shape). \
                        bernoulli_(1 - self.inp_drop_p). \
                        div(1 - self.inp_drop_p)
                mask_i = mask_i.to(inputs[t].dtype).to(inputs[t].device)
                inputs[t] = mask_i * inputs[t]

            state = self.base_cell(inputs[t], state)
            outputs.append(state)

            if self.training and self.recurrent_drop_p > 0.0 and t < len(
                    inputs) - 1:
                if t == 0:
                    mask_r = torch.rand(state.shape). \
                        bernoulli_(1 - self.recurrent_drop_p). \
                        div(1 - self.recurrent_drop_p)
                mask_r = mask_r.to(state.dtype).to(state.device)
                state = mask_r * state

        return torch.stack(outputs, dim=0), state
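All of these examples rely on the same TorchScript idiom: torch.jit.annotate(T, value) tells the scripting compiler the type of an expression it cannot infer on its own, most often an empty list or dict that is filled inside a loop. A minimal, self-contained sketch of the pattern (the function below is illustrative and not taken from any of the repositories above):

from torch import jit
from typing import Dict, List

@jit.script
def token_lengths(tokens: List[str]) -> Dict[str, int]:
    # An empty dict literal defaults to Dict[str, Tensor] under scripting,
    # so the value type has to be declared explicitly.
    lengths = jit.annotate(Dict[str, int], {})
    for token in tokens:
        lengths[token] = len(token)
    return lengths

print(token_lengths(["jit", "annotate"]))  # {'jit': 3, 'annotate': 8}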
Example #2
    def forward(
        self,
        tokens: torch.Tensor,
        token_embeddings: torch.Tensor,
        actions: List[List[int]],
        beam_size: int = 1,
        top_k: int = 1,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        actions_idx = jit.annotate(List[int], [])
        if self.training:
            # batch size is only 1 for now
            actions_idx = actions[0]
            assert len(
                actions_idx) > 0, "actions must be provided for training"
        else:
            torch.manual_seed(0)

        beam = [self.gen_init_state(tokens, token_embeddings)]
        all_finished = False
        while not all_finished:
            # Stores plans for expansion as (score, state, action)
            plans = jit.annotate(List[Plan], [])
            all_finished = True
            # Expand current beam states
            for state in beam:
                # Keep terminal states
                if state.finished():
                    plans.append(
                        Plan(state.neg_prob, const.TERMINAL_ELEMENT, state))
                else:
                    all_finished = False
                    plans.extend(self.gen_plans(state))

            beam.clear()
            # Take actions to regenerate the beam
            plans.sort()
            for plan in plans[:beam_size]:
                beam.append(self.execute_plan(plan, actions_idx, beam_size))

        # sanity check
        assert len(beam) > 0, "How come beam is empty?"

        beam.sort()
        res = jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
        for state in beam[:top_k]:
            res.append((
                torch.tensor([state.predicted_actions_idx],
                             device=self.device),
                # Unsqueeze to add batch dimension
                torch.cat(state.action_scores).unsqueeze(0),
            ))
        return res
Example #3
 def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
     # List[LSTMState]: [forward LSTMState, backward LSTMState]
     outputs = jit.annotate(List[Tensor], [])
     output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
     # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
     i = 0
     for direction in self.directions:
         state = states[i]
         out, out_state = direction(input, state)
         outputs += [out]
         output_states += [out_state]
         i += 1
     return torch.cat(outputs, -1), output_states
Example #4
 def forward(self, input, states):
     # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
     # List[LSTMState]: [forward LSTMState, backward LSTMState]
     outputs = jit.annotate(List[Tensor], [])
     output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
     for (i, direction) in enumerate(self.directions):
         state = states[i]
         out, out_state = direction(input, state)
         outputs += [out]
         output_states += [out_state]
     # tensor array concat assumes axis == 0 for now
     # return torch.cat(outputs, -1), output_states
     return torch.cat(outputs, 0), output_states
Example #5
    def forward(
        self, input_: Tensor, states: List[Union[Tuple[Tensor, Tensor],
                                                 Tensor]]
    ) -> Tuple[Tensor, List[Union[Tuple[Tensor, Tensor], Tensor]]]:
        # pylint: disable=arguments-differ
        # List[RNNState]: [forward RNNState, backward RNNState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])

        for direction, state in zip(self.directions, states):
            out, out_state = direction(input_, state)
            outputs += [out]
            output_states += [out_state]

        return cat(outputs, -1), output_states
Example #6
 def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
     inputs = reverse(input.unbind(0))
     outputs = jit.annotate(List[Tensor], [])
     for i in range(len(inputs)):
         out, state = self.cell(inputs[i], state)
         outputs += [out]
     return torch.stack(reverse(outputs)), state
Example #7
    def gen_plans(self, state: ParserState):
        plans = jit.annotate(List[Plan], [])
        # translating Expression p_t = affine_transform({pbias, S, stack_summary,
        # B, buffer_summary, A, action_summary});
        # list comprehension with ifs not supported by jit yet
        summaries = []
        for stack_tuple in (
            (state.stack_state_stack, self.ablation_use_stack),
            (state.buffer_state_stack, self.ablation_use_buffer),
            (state.action_state_stack, self.ablation_use_action),
        ):
            stack, flag = stack_tuple
            if flag:
                summaries.append(self.get_summary(stack))

        if self.ablation_use_last_open_NT_feature:
            # feature for index of last open non-terminal
            last_open_NT_feature = torch.zeros(self.num_actions)
            if len(state.open_NT) > 0:
                last_open_NT_feature[state.open_NT[-1]] = 1.0
            summaries.append(last_open_NT_feature.unsqueeze(0))

        state.action_p = self.action_linear(torch.cat(summaries, dim=1))
        log_probs = F.log_softmax(state.action_p, dim=1)[0]

        for action in self.valid_actions(state):
            plans.append(
                Plan(
                    score=state.neg_prob - float(log_probs[action].item()),
                    action=action,
                    state=state,
                ))
        return plans
Example #8
 def forward(self, logits: torch.Tensor):
     # In pure python, this code would be implemented as follows:
     #   scores = self.score_function(logits)
     #   return [
      #     {cls: score for cls, score in zip(self.classes, example_scores)}
     #     for example_scores in scores.tolist()
     #   ]
     # Extra verbosity is due to jit.script.
     scores = self.score_function(logits)
     results = jit.annotate(List[Dict[str, float]], [])
     for example_scores in scores.chunk(len(scores)):
         example_scores = example_scores.squeeze(dim=0)
         example_response = jit.annotate(Dict[str, float], {})
         for i in range(len(self.classes)):
             example_response[self.classes[i]] = example_scores[i].item()
         results.append(example_response)
     return results
Example #9
    def tokenize(self, tokens: List[str]) -> List[str]:
        bpe_tokens = jit.annotate(List[str], [])

        for token in tokens:
            # extend not implemented
            for part in self.bpe_token(token):
                bpe_tokens.append(part)

        return bpe_tokens
Example #10
 def forward(self, inputs, state):
     # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
     outputs = jit.annotate(List[Tensor], [])
     seq_len = inputs.size(0)
     for i in range(seq_len):
         out, state = self.cell(inputs[seq_len - i - 1], state)
         # workaround for the lack of list rev support
         outputs = [out] + outputs
     return torch.stack(outputs), state
Example #11
 def forward(
     self, input_: Tensor, state: Union[Tuple[Tensor, Tensor], Tensor]
 ) -> Tuple[Tensor, Union[Tuple[Tensor, Tensor], Tensor]]:
     # pylint: disable=arguments-differ
     inputs = self.reverse(input_.unbind(0))
     outputs = jit.annotate(List[Tensor], [])
     for input_values in inputs:
         out, state = self.cell(input_values, state)
         outputs += [out]
     return stack(self.reverse(outputs)), state
Example #12
 def forward(self, input, state):
     # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
     inputs = torch.split(input, 1)
     inputs.reverse()
     outputs = jit.annotate(List[Tensor], [])
     for i in range(len(inputs)):
         out, state = self.cell(inputs[i], state)
         outputs += [out]
     outputs.reverse()
     return torch.stack(outputs), state
Example #13
 def forward(self, input, states):
     # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
     # List[LSTMState]: One state per layer
     output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
     output = input
     for (i, rnn_layer) in enumerate(self.layers):
         state = states[i]
         output, out_state = rnn_layer(output, state)
         output_states += [out_state]
     return output, output_states
Example #14
 def forward(
         self, input_: Tensor,
         state: Tuple[Tensor,
                      Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
     # pylint: disable=arguments-differ
     inputs = input_.unbind(0)
     outputs = jit.annotate(List[Tensor], [])
     for input_item in inputs:
         out, state = self.cell(input_item, state)
         outputs += [out]
     return stack(outputs), state
Example #15
 def forward(self, input, states):
     # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
     # List[List[LSTMState]]: The outer list is for layers,
     #                        inner list is for directions.
     output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
     output = input
     for (i, rnn_layer) in enumerate(self.layers):
         state = states[i]
         output, out_state = rnn_layer(output, state)
         output_states += [out_state]
     return output, output_states
Example #16
 def lookup_words_1d(
     self, values: Tensor, filter_token_list: List[int] = ()) -> List[str]:
     result = jit.annotate(List[str], [])
     for idx in range(values.size(0)):
         value = int(values[idx])
         if not list_membership(value, filter_token_list):
             if value < len(self.vocab):
                 result.append(self.vocab[int(value)])
             else:
                 result.append(self.vocab[self.unk_idx])
     return result
Example #17
    def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)

            # print("- BidirLSTMLayer. out.shape={}, out_state[0].shape={}, out_state[1].shape={}".format(out.shape, out_state[0].shape, out_state[1].shape))

            outputs += [out]
            output_states += [out_state]
            i += 1

        # print("BidirLSTMLayer. len(output_states)={}".format(len(output_states)))

        return torch.cat(outputs, -1), output_states
Example #18
    def __init__(
        self,
        buffer_stack: LSTMStateStack,
        stack_stack: LSTMStateStack,
        action_stack: LSTMStateStack,
    ):
        self.buffer_state_stack = buffer_stack
        self.stack_state_stack = stack_stack
        self.action_state_stack = action_stack

        self.predicted_actions_idx = jit.annotate(List[int], [])
        self.action_scores = []

        self.is_open_NT = jit.annotate(List[bool], [])
        self.open_NT = jit.annotate(List[int], [])
        self.found_unsupported = False
        # dummy tensor as place holder
        self.action_p = torch.zeros(1)

        # negative cumulative log prob so sort(states) is in descending order
        self.neg_prob = 0.0
Example #19
 def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
     # List[LSTMState]: One state per layer
     output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
     output = input
     # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
     i = 0
     for rnn_layer in self.layers:
         state = states[i]
         output, out_state = rnn_layer(output, state)
         output_states += [out_state]
         i += 1
     return output, output_states
Example #20
 def forward(self, input, states=None, mask=None):
     # type: (Tensor, Optional[List[Tuple[Tensor, Tensor]]], Optional[Tensor]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
     output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
     output = input
     # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
     i = 0
     for layer in self.layers:
         state = states[i] if states is not None else None
         output, out_state = layer(output, state, mask)
         output_states.append(out_state)
         i += 1
     return output, output_states
Example #21
 def forward(
     self,
     tokens: List[str],
     dict_feat: Tuple[List[str], List[float], List[int]],
     contextual_token_embeddings: List[float],
     beam_size: int = 1,
     top_k: int = 1,
 ):
     token_ids = self.word_vocab.lookup_indices_1d(self.unkify(tokens))
     dict_tokens, dict_weights, dict_lengths = dict_feat
     dict_ids = self.dict_vocab.lookup_indices_1d(dict_tokens)
     token_ids_tensor = torch.tensor([token_ids])
     embed = self.embedding(
         token_ids_tensor,
         (
             torch.tensor([dict_ids]),
             torch.tensor([dict_weights], dtype=torch.float),
             torch.tensor([dict_lengths]),
         ),
         torch.tensor([contextual_token_embeddings], dtype=torch.float),
     )
     raw_results = self.jit_module(
         tokens=token_ids_tensor,
         token_embeddings=embed,
         actions=(),
         beam_size=beam_size,
         top_k=top_k,
     )
     results = jit.annotate(List[Tuple[List[str], List[float]]], [])
     for result in raw_results:
         actions, scores = result
         seq_logical = self.actions_to_seqlogical(actions.squeeze(0),
                                                  tokens)
         normalized_scores = F.softmax(scores, 2).max(2)[0].squeeze(0)
         float_scores = jit.annotate(List[float], [])
         # TODO this can be done more efficiently once JIT provide native support
         for idx in range(normalized_scores.size(0)):
             float_scores.append(float(normalized_scores[idx]))
         results.append((seq_logical, float_scores))
     return results
Example #22
 def forward(
         self,
         input: Tensor,
         h_0: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
     output = jit.annotate(Tensor, torch.zeros(1))
     l = 1
     for module in self._modules_list:
         output, h_0 = module(input, h_0)
         if l % self.skip_length == 0:
             input = output + input
         else:
             input = output
         l += 1
     return output, h_0
Example #23
 def actions_to_seqlogical(self, actions, tokens: List[str]):
     token_idx = 0
     res = jit.annotate(List[str], [])
     for idx in range(actions.size(0)):
         action = int(actions[idx])
         if action == self.jit_module.reduce_idx:
             res.append(self.CLOSE_BRACKET)
         elif action == self.jit_module.shift_idx:
             res.append(tokens[token_idx])
             token_idx += 1
         else:
             res.append(self.OPEN_BRACKET)
             res.append(self.action_vocab.lookup_word(action))
     return res
Example #24
            def forward(self, tokens: List[List[str]]):
                word_ids = self.vocab.lookup_indices_2d(tokens)

                seq_lens = jit.annotate(List[int], [])

                for sentence in word_ids:
                    seq_lens.append(len(sentence))
                pad_to_length = list_max(seq_lens)
                for sentence in word_ids:
                    for _ in range(pad_to_length - len(sentence)):
                        sentence.append(self.pad_idx)

                logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
                return self.output_layer(logits)
Example #25
 def forward(self, input, states):
     # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
     # List[List[LSTMState]]: The outer list is for layers,
     #                        inner list is for directions.
     output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
     output = input
     # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
     i = 0
     for rnn_layer in self.layers:
         state = states[i]
         output, out_state = rnn_layer(output, state)
         output_states += [out_state]
         i += 1
     return output, output_states
Example #26
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]

        # print("ReverseLSTMLayer.forward. input.shape={}".format(input.shape))

        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]

        # print("ReverseLSTMLayer. len(state)={}".format(len(state)))

        return torch.stack(reverse(outputs)), state
Example #27
    def forward(self, input_: Tensor, states: List) -> Tuple[Tensor, List]:
        # pylint: disable=arguments-differ
        # List[RNNState]: One state per layer.
        output_states = jit.annotate(List, [])
        output = input_

        for i, rnn_layer in enumerate(self.layers):
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply the dropout layer except the last layer.
            if i < self.num_layers - 1 and self.dropout_layer is not None:
                output = self.dropout_layer(output)
            output_states += [out_state]

        return output, output_states
Example #28
 def forward(self, input, states):
     # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
     # List[LSTMState]: One state per layer
     output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
     output = input
     # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
     i = 0
     for rnn_layer in self.layers:
         state = states[i]
         output, out_state = rnn_layer(output, state)
         # Apply the dropout layer except the last layer
         if i < self.num_layers - 1:
             output = self.dropout_layer(output)
         output_states += [out_state]
         i += 1
     return output, output_states
Example #29
    def valid_actions(self, state: ParserState) -> List[int]:

        valid_actions = jit.annotate(List[int], [])
        is_open_NT = state.is_open_NT
        num_open_NT = len(state.open_NT)
        stack = state.stack_state_stack
        buffer = state.buffer_state_stack

        # Can REDUCE if
        # 1. Top of multi-element stack is not an NT, and
        # 2. Two open NTs on stack, or buffer is empty
        if (len(is_open_NT) > 0 and not is_open_NT[-1]
                and not len(is_open_NT) == 1) and (num_open_NT >= 2
                                                   or buffer.size() == 0):
            assert stack.size() > 0
            valid_actions.append(self.reduce_idx)

        if buffer.size() > 0 and num_open_NT < self.max_open_NT:
            if (not self.training) or self.constraints_intent_slot_nesting:
                # if stack is empty or the last open NT is slot
                if (len(state.open_NT) == 0) or list_membership(
                        state.open_NT[-1], self.valid_SL_idxs):
                    valid_actions += self.valid_IN_idxs
                elif list_membership(state.open_NT[-1], self.valid_IN_idxs):
                    if not (self.constraints_no_slots_inside_unsupported
                            and state.found_unsupported):
                        valid_actions += self.valid_SL_idxs
            else:
                valid_actions.extend(self.valid_IN_idxs)
                valid_actions.extend(self.valid_SL_idxs)

        elif (not self.training) and num_open_NT >= self.max_open_NT:
            print("not predicting NT, buffer len is {}, num open NTs is {}".
                  format(buffer.size(), num_open_NT))

        # Can SHIFT if
        # 1. Buffer is non-empty, and
        # 2. At least one open NT on stack
        if buffer.size() > 0 and num_open_NT >= 1:
            valid_actions.append(self.shift_idx)

        return valid_actions
Example #30
 def forward(self, inputs, fhiddens, bhiddens, fstates, bstates):
     outputs = jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
     outputf = inputs
     outputb = inputs
     i = 0
     for layer1, layer2, dropout1, dropout2 in zip(self.forward_model,
                                                   self.backward_model,
                                                   self.dropoutf,
                                                   self.dropoutb):
         fstate = fstates[i]
         bstate = bstates[i]
         fhidden = fhiddens[i]
         bhidden = bhiddens[i]
         outputf, fstate = layer1(outputf, fstate, fhidden)
         outputb, bstate = layer2(outputb, bstate, bhidden)
         outputf = dropout1(outputf)
         outputb = dropout2(outputb)
         i += 1
         outputs += [(fstate, bstate)]
     return torch.cat((outputf, outputb), -1), outputs
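
Inside a scripted nn.Module the idiom is the same; on recent PyTorch releases a plain variable annotation such as outputs: List[Tensor] = [] can stand in for the jit.annotate call. A minimal sketch that mirrors the per-timestep loops above (the module is illustrative and not taken from any of the repositories):

import torch
from torch import Tensor, jit
from typing import List


class DoubleEachStep(torch.nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        # Typed empty list collecting the output of each timestep.
        outputs = jit.annotate(List[Tensor], [])
        for step in input.unbind(0):  # iterate over the first (time) dimension
            outputs.append(step * 2.0)
        return torch.stack(outputs, dim=0)


scripted = jit.script(DoubleEachStep())
print(scripted(torch.ones(3, 4)).shape)  # torch.Size([3, 4])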