def forward(
    self, input: Tensor, state: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
    inputs = input.unbind(0)
    outputs = jit.annotate(List[Tensor], [])
    mask_r = jit.annotate(Tensor, torch.zeros(1))
    mask_i = jit.annotate(Tensor, torch.zeros(1))
    for t in range(len(inputs)):
        if self.training and self.inp_drop_p > 0.0:
            if t == 0:
                # Sample the input dropout mask once and reuse it at every time step.
                mask_i = (
                    torch.rand(inputs[t].shape)
                    .bernoulli_(1 - self.inp_drop_p)
                    .div(1 - self.inp_drop_p)
                )
                mask_i = mask_i.to(inputs[t].dtype).to(inputs[t].device)
            inputs[t] = mask_i * inputs[t]
        state = self.base_cell(inputs[t], state)
        outputs.append(state)
        if self.training and self.recurrent_drop_p > 0.0 and t < len(inputs) - 1:
            if t == 0:
                # Sample the recurrent dropout mask once and reuse it at every time step.
                mask_r = (
                    torch.rand(state.shape)
                    .bernoulli_(1 - self.recurrent_drop_p)
                    .div(1 - self.recurrent_drop_p)
                )
                mask_r = mask_r.to(state.dtype).to(state.device)
            state = mask_r * state
    return torch.stack(outputs, dim=0), state
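# A minimal standalone sketch of the variational ("locked") dropout used above:
# a single Bernoulli mask is sampled per sequence, scaled by 1 / (1 - p) so the
# expected activation is unchanged, and reused at every time step. The helper
# name `apply_shared_dropout_mask` is illustrative, not part of this module.
import torch
from torch import Tensor
from typing import List


def apply_shared_dropout_mask(steps: List[Tensor], p: float) -> List[Tensor]:
    # Sample the mask once from the first step's shape, then reuse it everywhere.
    mask = torch.rand(steps[0].shape).bernoulli_(1 - p).div(1 - p)
    return [mask * step for step in steps]


# Example: three time steps of a batch of 2 with 4 features each.
dropped = apply_shared_dropout_mask([torch.randn(2, 4) for _ in range(3)], p=0.25)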
def forward(
    self,
    tokens: torch.Tensor,
    token_embeddings: torch.Tensor,
    actions: List[List[int]],
    beam_size: int = 1,
    top_k: int = 1,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    actions_idx = jit.annotate(List[int], [])
    if self.training:
        # batch size is only 1 for now
        actions_idx = actions[0]
        assert len(actions_idx) > 0, "actions must be provided for training"
    else:
        torch.manual_seed(0)

    beam = [self.gen_init_state(tokens, token_embeddings)]
    all_finished = False
    while not all_finished:
        # Stores plans for expansion as (score, state, action)
        plans = jit.annotate(List[Plan], [])
        all_finished = True
        # Expand current beam states
        for state in beam:
            # Keep terminal states
            if state.finished():
                plans.append(Plan(state.neg_prob, const.TERMINAL_ELEMENT, state))
            else:
                all_finished = False
                plans.extend(self.gen_plans(state))
        beam.clear()
        # Take actions to regenerate the beam
        plans.sort()
        for plan in plans[:beam_size]:
            beam.append(self.execute_plan(plan, actions_idx, beam_size))

    # sanity check
    assert len(beam) > 0, "How come beam is empty?"
    beam.sort()
    res = jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
    for state in beam[:top_k]:
        res.append(
            (
                torch.tensor([state.predicted_actions_idx], device=self.device),
                # Unsqueeze to add batch dimension
                torch.cat(state.action_scores).unsqueeze(0),
            )
        )
    return res
def forward(
    self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
    # List[LSTMState]: [forward LSTMState, backward LSTMState]
    outputs = jit.annotate(List[Tensor], [])
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    i = 0
    for direction in self.directions:
        state = states[i]
        out, out_state = direction(input, state)
        outputs += [out]
        output_states += [out_state]
        i += 1
    return torch.cat(outputs, -1), output_states
def forward(self, input, states):
    # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    # List[LSTMState]: [forward LSTMState, backward LSTMState]
    outputs = jit.annotate(List[Tensor], [])
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    for (i, direction) in enumerate(self.directions):
        state = states[i]
        out, out_state = direction(input, state)
        outputs += [out]
        output_states += [out_state]
    # tensor array concat assumes axis == 0 for now
    # return torch.cat(outputs, -1), output_states
    return torch.cat(outputs, 0), output_states
def forward(
    self, input_: Tensor, states: List[Union[Tuple[Tensor, Tensor], Tensor]]
) -> Tuple[Tensor, List[Union[Tuple[Tensor, Tensor], Tensor]]]:
    # pylint: disable=arguments-differ
    # List[RNNState]: [forward RNNState, backward RNNState]
    outputs = jit.annotate(List[Tensor], [])
    output_states = jit.annotate(List[Union[Tuple[Tensor, Tensor], Tensor]], [])
    for direction, state in zip(self.directions, states):
        out, out_state = direction(input_, state)
        outputs += [out]
        output_states += [out_state]
    return cat(outputs, -1), output_states
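# A small standalone sketch of the bidirectional pattern used by the layers
# above: run one module over the sequence as-is and another over the reversed
# sequence, then concatenate their per-step outputs along the feature
# dimension. The two GRUs are stand-ins for the `self.directions` modules,
# which are not defined in this file.
import torch
from torch import nn

seq_len, batch, hidden = 5, 2, 8
fwd = nn.GRU(hidden, hidden)
bwd = nn.GRU(hidden, hidden)

x = torch.randn(seq_len, batch, hidden)
out_f, _ = fwd(x)
out_b, _ = bwd(torch.flip(x, dims=[0]))
# Re-reverse the backward outputs so step t of both directions lines up.
out_b = torch.flip(out_b, dims=[0])
bidir = torch.cat([out_f, out_b], dim=-1)  # (seq_len, batch, 2 * hidden)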
def forward(
    self, input: Tensor, state: Tuple[Tensor, Tensor]
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    inputs = reverse(input.unbind(0))
    outputs = jit.annotate(List[Tensor], [])
    for i in range(len(inputs)):
        out, state = self.cell(inputs[i], state)
        outputs += [out]
    return torch.stack(reverse(outputs)), state
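# The free function `reverse` used above is defined elsewhere in the module. A
# plausible TorchScript-friendly version (an assumption, not the original
# definition) rebuilds the list back to front, mirroring the "lack of list rev
# support" workarounds found in the other layers in this file.
import torch
from torch import Tensor
from typing import List


def reverse(lst: List[Tensor]) -> List[Tensor]:
    # Walk the input from the last element to the first and collect the result.
    rev = torch.jit.annotate(List[Tensor], [])
    for i in range(len(lst)):
        rev.append(lst[len(lst) - 1 - i])
    return rev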
def gen_plans(self, state: ParserState):
    plans = jit.annotate(List[Plan], [])
    # translating Expression p_t = affine_transform({pbias, S, stack_summary,
    #     B, buffer_summary, A, action_summary});
    # list comprehension with ifs not supported by jit yet
    summaries = []
    for stack_tuple in (
        (state.stack_state_stack, self.ablation_use_stack),
        (state.buffer_state_stack, self.ablation_use_buffer),
        (state.action_state_stack, self.ablation_use_action),
    ):
        stack, flag = stack_tuple
        if flag:
            summaries.append(self.get_summary(stack))

    if self.ablation_use_last_open_NT_feature:
        # feature for index of last open non-terminal
        last_open_NT_feature = torch.zeros(self.num_actions)
        if len(state.open_NT) > 0:
            last_open_NT_feature[state.open_NT[-1]] = 1.0
        summaries.append(last_open_NT_feature.unsqueeze(0))

    state.action_p = self.action_linear(torch.cat(summaries, dim=1))
    log_probs = F.log_softmax(state.action_p, dim=1)[0]
    for action in self.valid_actions(state):
        plans.append(
            Plan(
                score=state.neg_prob - float(log_probs[action].item()),
                action=action,
                state=state,
            )
        )
    return plans
def forward(self, logits: torch.Tensor):
    # In pure python, this code would be implemented as follows:
    # scores = self.score_function(logits)
    # return [
    #     {class: score for class, score in zip(self.classes, example_scores)}
    #     for example_scores in scores.tolist()
    # ]
    # Extra verbosity is due to jit.script.
    scores = self.score_function(logits)
    results = jit.annotate(List[Dict[str, float]], [])
    for example_scores in scores.chunk(len(scores)):
        example_scores = example_scores.squeeze(dim=0)
        example_response = jit.annotate(Dict[str, float], {})
        for i in range(len(self.classes)):
            example_response[self.classes[i]] = example_scores[i].item()
        results.append(example_response)
    return results
def tokenize(self, tokens: List[str]) -> List[str]:
    bpe_tokens = jit.annotate(List[str], [])
    for token in tokens:
        # extend not implemented
        for part in self.bpe_token(token):
            bpe_tokens.append(part)
    return bpe_tokens
def forward(self, inputs, state):
    # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
    outputs = jit.annotate(List[Tensor], [])
    seq_len = inputs.size(0)
    for i in range(seq_len):
        out, state = self.cell(inputs[seq_len - i - 1], state)
        # workaround for the lack of list rev support
        outputs = [out] + outputs
    return torch.stack(outputs), state
def forward(
    self, input_: Tensor, state: Union[Tuple[Tensor, Tensor], Tensor]
) -> Tuple[Tensor, Union[Tuple[Tensor, Tensor], Tensor]]:
    # pylint: disable=arguments-differ
    inputs = self.reverse(input_.unbind(0))
    outputs = jit.annotate(List[Tensor], [])
    for input_values in inputs:
        out, state = self.cell(input_values, state)
        outputs += [out]
    return stack(self.reverse(outputs)), state
def forward(self, input, state):
    # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
    inputs = torch.split(input, 1)
    inputs.reverse()
    outputs = jit.annotate(List[Tensor], [])
    for i in range(len(inputs)):
        out, state = self.cell(inputs[i], state)
        outputs += [out]
    outputs.reverse()
    return torch.stack(outputs), state
def forward(self, input, states):
    # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    # List[LSTMState]: One state per layer
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    output = input
    for (i, rnn_layer) in enumerate(self.layers):
        state = states[i]
        output, out_state = rnn_layer(output, state)
        output_states += [out_state]
    return output, output_states
def forward(
    self, input_: Tensor, state: Tuple[Tensor, Tensor]
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    # pylint: disable=arguments-differ
    inputs = input_.unbind(0)
    outputs = jit.annotate(List[Tensor], [])
    for input_item in inputs:
        out, state = self.cell(input_item, state)
        outputs += [out]
    return stack(outputs), state
def forward(self, input, states):
    # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
    # List[List[LSTMState]]: The outer list is for layers,
    # inner list is for directions.
    output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
    output = input
    for (i, rnn_layer) in enumerate(self.layers):
        state = states[i]
        output, out_state = rnn_layer(output, state)
        output_states += [out_state]
    return output, output_states
def lookup_words_1d(
    self, values: Tensor, filter_token_list: List[int] = ()
) -> List[str]:
    result = jit.annotate(List[str], [])
    for idx in range(values.size(0)):
        value = int(values[idx])
        if not list_membership(value, filter_token_list):
            if value < len(self.vocab):
                result.append(self.vocab[value])
            else:
                result.append(self.vocab[self.unk_idx])
    return result
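# `list_membership` is defined elsewhere. A plausible reading (an assumption,
# not the original implementation) is a scripting-friendly stand-in for the
# `in` operator, which older TorchScript releases did not support for
# arbitrary lists.
from typing import List


def list_membership(item: int, lst: List[int]) -> bool:
    # Linear scan, equivalent to `item in lst`.
    item_present = False
    for i in lst:
        if item == i:
            item_present = True
    return item_present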
def forward(self, input, states):
    # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    # List[LSTMState]: [forward LSTMState, backward LSTMState]
    outputs = jit.annotate(List[Tensor], [])
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    i = 0
    for direction in self.directions:
        state = states[i]
        out, out_state = direction(input, state)
        # print("- BidirLSTMLayer. out.shape={}, out_state[0].shape={}, out_state[1].shape={}".format(out.shape, out_state[0].shape, out_state[1].shape))
        outputs += [out]
        output_states += [out_state]
        i += 1
    # print("BidirLSTMLayer. len(output_states)={}".format(len(output_states)))
    return torch.cat(outputs, -1), output_states
def __init__(
    self,
    buffer_stack: LSTMStateStack,
    stack_stack: LSTMStateStack,
    action_stack: LSTMStateStack,
):
    self.buffer_state_stack = buffer_stack
    self.stack_state_stack = stack_stack
    self.action_state_stack = action_stack
    self.predicted_actions_idx = jit.annotate(List[int], [])
    self.action_scores = []
    self.is_open_NT = jit.annotate(List[bool], [])
    self.open_NT = jit.annotate(List[int], [])
    self.found_unsupported = False
    # dummy tensor as place holder
    self.action_p = torch.zeros(1)
    # negative cumulative log prob so sort(states) is in descending order
    self.neg_prob = 0.0
def forward(
    self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
    # List[LSTMState]: One state per layer
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    output = input
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    i = 0
    for rnn_layer in self.layers:
        state = states[i]
        output, out_state = rnn_layer(output, state)
        output_states += [out_state]
        i += 1
    return output, output_states
def forward(self, input, states=None, mask=None):
    # type: (Tensor, Optional[List[Tuple[Tensor, Tensor]]], Optional[Tensor]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    output = input
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    i = 0
    for layer in self.layers:
        state = states[i] if states is not None else None
        output, out_state = layer(output, state, mask)
        output_states.append(out_state)
        i += 1
    return output, output_states
def forward(
    self,
    tokens: List[str],
    dict_feat: Tuple[List[str], List[float], List[int]],
    contextual_token_embeddings: List[float],
    beam_size: int = 1,
    top_k: int = 1,
):
    token_ids = self.word_vocab.lookup_indices_1d(self.unkify(tokens))
    dict_tokens, dict_weights, dict_lengths = dict_feat
    dict_ids = self.dict_vocab.lookup_indices_1d(dict_tokens)
    token_ids_tensor = torch.tensor([token_ids])
    embed = self.embedding(
        token_ids_tensor,
        (
            torch.tensor([dict_ids]),
            torch.tensor([dict_weights], dtype=torch.float),
            torch.tensor([dict_lengths]),
        ),
        torch.tensor([contextual_token_embeddings], dtype=torch.float),
    )
    raw_results = self.jit_module(
        tokens=token_ids_tensor,
        token_embeddings=embed,
        actions=(),
        beam_size=beam_size,
        top_k=top_k,
    )
    results = jit.annotate(List[Tuple[List[str], List[float]]], [])
    for result in raw_results:
        actions, scores = result
        seq_logical = self.actions_to_seqlogical(actions.squeeze(0), tokens)
        normalized_scores = F.softmax(scores, 2).max(2)[0].squeeze(0)
        float_scores = jit.annotate(List[float], [])
        # TODO this can be done more efficiently once JIT provides native support
        for idx in range(normalized_scores.size(0)):
            float_scores.append(float(normalized_scores[idx]))
        results.append((seq_logical, float_scores))
    return results
def forward(
    self, input: Tensor, h_0: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
    output = jit.annotate(Tensor, torch.zeros(1))
    l = 1
    for module in self._modules_list:
        output, h_0 = module(input, h_0)
        if l % self.skip_length == 0:
            input = output + input
        else:
            input = output
        l += 1
    return output, h_0
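# A stripped-down sketch of the skip pattern above (all names here are
# illustrative): every `skip_length`-th block adds its output back onto its
# input, a plain residual connection, while the other blocks simply feed
# forward.
import torch
from torch import nn


class SkipStack(nn.Module):
    def __init__(self, dim: int, num_layers: int, skip_length: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.skip_length = skip_length

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx, layer in enumerate(self.layers, start=1):
            out = layer(x)
            # Residual connection on every skip_length-th layer.
            x = out + x if idx % self.skip_length == 0 else out
        return x


y = SkipStack(dim=16, num_layers=4)(torch.randn(3, 16))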
def actions_to_seqlogical(self, actions, tokens: List[str]):
    token_idx = 0
    res = jit.annotate(List[str], [])
    for idx in range(actions.size(0)):
        action = int(actions[idx])
        if action == self.jit_module.reduce_idx:
            res.append(self.CLOSE_BRACKET)
        elif action == self.jit_module.shift_idx:
            res.append(tokens[token_idx])
            token_idx += 1
        else:
            res.append(self.OPEN_BRACKET)
            res.append(self.action_vocab.lookup_word(action))
    return res
def forward(self, tokens: List[List[str]]):
    word_ids = self.vocab.lookup_indices_2d(tokens)
    seq_lens = jit.annotate(List[int], [])
    for sentence in word_ids:
        seq_lens.append(len(sentence))
    pad_to_length = list_max(seq_lens)
    for sentence in word_ids:
        for _ in range(pad_to_length - len(sentence)):
            sentence.append(self.pad_idx)
    logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
    return self.output_layer(logits)
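# A minimal standalone illustration of the padding step above: every sentence
# in the batch is right-padded with a pad index up to the longest length so
# the batch can be turned into one rectangular tensor. The names and values
# here are illustrative.
import torch
from typing import List

PAD_IDX = 0
batch: List[List[int]] = [[4, 8, 15], [16, 23], [42]]
seq_lens = [len(sentence) for sentence in batch]
pad_to_length = max(seq_lens)
padded = [sentence + [PAD_IDX] * (pad_to_length - len(sentence)) for sentence in batch]
word_ids = torch.tensor(padded)   # shape (3, 3)
lengths = torch.tensor(seq_lens)  # original lengths, e.g. for packing or masking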
def forward(self, input, states):
    # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
    # List[List[LSTMState]]: The outer list is for layers,
    # inner list is for directions.
    output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
    output = input
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    i = 0
    for rnn_layer in self.layers:
        state = states[i]
        output, out_state = rnn_layer(output, state)
        output_states += [out_state]
        i += 1
    return output, output_states
def forward(self, input, state):
    # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
    # print("ReverseLSTMLayer.forward. input.shape={}".format(input.shape))
    inputs = reverse(input.unbind(0))
    outputs = jit.annotate(List[Tensor], [])
    for i in range(len(inputs)):
        out, state = self.cell(inputs[i], state)
        outputs += [out]
    # print("ReverseLSTMLayer. len(state)={}".format(len(state)))
    return torch.stack(reverse(outputs)), state
def forward(self, input_: Tensor, states: List) -> Tuple[Tensor, List]:
    # pylint: disable=arguments-differ
    # List[RNNState]: One state per layer.
    output_states = jit.annotate(List, [])
    output = input_
    for i, rnn_layer in enumerate(self.layers):
        state = states[i]
        output, out_state = rnn_layer(output, state)
        # Apply dropout to the output of every layer except the last.
        if i < self.num_layers - 1 and self.dropout_layer is not None:
            output = self.dropout_layer(output)
        output_states += [out_state]
    return output, output_states
def forward(self, input, states):
    # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    # List[LSTMState]: One state per layer
    output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    output = input
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    i = 0
    for rnn_layer in self.layers:
        state = states[i]
        output, out_state = rnn_layer(output, state)
        # Apply dropout to the output of every layer except the last
        if i < self.num_layers - 1:
            output = self.dropout_layer(output)
        output_states += [out_state]
        i += 1
    return output, output_states
def valid_actions(self, state: ParserState) -> List[int]:
    valid_actions = jit.annotate(List[int], [])
    is_open_NT = state.is_open_NT
    num_open_NT = len(state.open_NT)
    stack = state.stack_state_stack
    buffer = state.buffer_state_stack

    # Can REDUCE if
    # 1. Top of multi-element stack is not an NT, and
    # 2. Two open NTs on stack, or buffer is empty
    if (
        len(is_open_NT) > 0
        and not is_open_NT[-1]
        and not len(is_open_NT) == 1
    ) and (num_open_NT >= 2 or buffer.size() == 0):
        assert stack.size() > 0
        valid_actions.append(self.reduce_idx)

    if buffer.size() > 0 and num_open_NT < self.max_open_NT:
        if (not self.training) or self.constraints_intent_slot_nesting:
            # if stack is empty or the last open NT is a slot
            if (len(state.open_NT) == 0) or list_membership(
                state.open_NT[-1], self.valid_SL_idxs
            ):
                valid_actions += self.valid_IN_idxs
            elif list_membership(state.open_NT[-1], self.valid_IN_idxs):
                if not (
                    self.constraints_no_slots_inside_unsupported
                    and state.found_unsupported
                ):
                    valid_actions += self.valid_SL_idxs
        else:
            valid_actions.extend(self.valid_IN_idxs)
            valid_actions.extend(self.valid_SL_idxs)
    elif (not self.training) and num_open_NT >= self.max_open_NT:
        print(
            "not predicting NT, buffer len is {}, num open NTs is {}".format(
                buffer.size(), num_open_NT
            )
        )

    # Can SHIFT if
    # 1. Buffer is non-empty, and
    # 2. At least one open NT on stack
    if buffer.size() > 0 and num_open_NT >= 1:
        valid_actions.append(self.shift_idx)

    return valid_actions
def forward(self, inputs, fhiddens, bhiddens, fstates, bstates):
    outputs = jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
    outputf = inputs
    outputb = inputs
    i = 0
    for layer1, layer2, dropout1, dropout2 in zip(
        self.forward_model, self.backward_model, self.dropoutf, self.dropoutb
    ):
        fstate = fstates[i]
        bstate = bstates[i]
        fhidden = fhiddens[i]
        bhidden = bhiddens[i]
        outputf, fstate = layer1(outputf, fstate, fhidden)
        outputb, bstate = layer2(outputb, bstate, bhidden)
        outputf = dropout1(outputf)
        outputb = dropout2(outputb)
        i += 1
        outputs += [(fstate, bstate)]
    return torch.cat((outputf, outputb), -1), outputs