def _run_prediction_batch(self, batch):
    pred_sentences = []
    gold_sentences = []

    sentence_features, sentences = batch
    sentence_features = torch.squeeze(sentence_features, dim=1)

    sentence_lengths = [len(sentence) for sentence in sentences]
    sentence_lengths = torch.tensor(
        sentence_lengths, dtype=torch.int64, device=self.device
    )

    contextualized_tokens_batch = self.model.compute_lstm_output(
        sentence_features, sentence_lengths
    )

    configurations = [
        Configuration(
            sentence,
            contextualized_input,
            self.model,
            sentence_features=sentence_feature,
        )
        for contextualized_input, sentence, sentence_feature in zip(
            contextualized_tokens_batch, sentences, sentence_features
        )
    ]

    while configurations:
        # Pass the stacks and buffers through the MLPs in one batch
        configurations, _, _ = self._update_classification_scores(configurations)

        # Transitions must be applied sequentially, one configuration
        # at a time, since each action mutates its stack and buffer.
        for configuration in configurations:
            # Predict a list of possible actions: transitions, their
            # label (if the transition is LEFT_ARC / RIGHT_ARC) and the
            # score of the action based on the MLP output.
            actions = configuration.predict_actions()

            if not configuration.swap_possible:
                # Exclude swap options
                actions = [
                    action for action in actions if action.transition != T.SWAP
                ]

            assert actions

            best_action = Configuration.get_best_action(actions)
            if best_action.transition == T.SWAP:
                configuration.num_swap += 1

            configuration.apply_transition(best_action)

            if configuration.is_terminal:
                pred_sentences.append(configuration.predicted_sentence)
                gold_sentences.append(configuration.sentence)

        # Remove all finished configurations
        configurations = [c for c in configurations if not c.is_terminal]

    return pred_sentences, gold_sentences
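# NOTE (illustrative sketch, not part of this codebase): the loop above
# scores all live configurations in one batched forward pass, but applies
# transitions sequentially, because every action mutates that
# configuration's stack and buffer. ToyConfig below is a hypothetical
# stand-in for Configuration that demonstrates the pattern in isolation.
import torch

class ToyConfig:
    """Hypothetical stand-in for Configuration: counts remaining steps."""

    def __init__(self, steps):
        self.steps = steps

    @property
    def is_terminal(self):
        return self.steps == 0

    def apply_best(self, scores):
        _ = torch.argmax(scores)  # pick the argmax action ...
        self.steps -= 1           # ... here we simply consume one step

def _sketch_batched_prediction_loop():
    configs = [ToyConfig(steps=n) for n in (1, 2, 3)]
    while configs:
        # One batched forward pass for all live configurations ...
        batch_scores = torch.randn(len(configs), 4)
        # ... but transitions are applied one configuration at a time.
        for config, scores in zip(configs, batch_scores):
            config.apply_best(scores)
        # Drop finished parses, exactly as in _run_prediction_batch.
        configs = [c for c in configs if not c.is_terminal]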
def generate_configuration(
    self, sentence: Sentence
) -> Iterator[ConfigurationItem]:
    configuration = Configuration(
        sentence, contextualized_input=None, model=None, device=self.device
    )

    # Deduplicate items: two configurations with identical stack and
    # buffer contents would yield identical training features.
    item_filter = set()
    for configuration_item in self._generate_next_datapoint(configuration):
        feature_key = (
            tuple(configuration_item.stack.cpu().tolist()),
            tuple(configuration_item.buffer.cpu().tolist()),
        )
        if feature_key not in item_filter:
            item_filter.add(feature_key)
            yield configuration_item
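# NOTE (illustrative sketch, not part of this codebase): tensors are not
# hashable, so generate_configuration snapshots stack and buffer contents
# as tuples before using them as set members. The deduplication idiom in
# isolation:
import torch

def _sketch_tensor_dedup():
    seen = set()
    items = [torch.tensor([1, 2]), torch.tensor([1, 2]), torch.tensor([3])]
    unique = []
    for t in items:
        key = tuple(t.cpu().tolist())  # hashable snapshot of the contents
        if key not in seen:
            seen.add(key)
            unique.append(t)
    return unique  # the duplicate [1, 2] tensor is skipped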
def _process_training_batch(
    self,
    batch: Tuple[torch.Tensor, List[Sentence]],
    error_probability: float,
    margin_threshold: float,
    criterion: Optional[_Loss] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Parses the sentences in the given batch and returns the loss:
    a single cross-entropy tensor if `criterion` is given, otherwise
    a list of margin loss terms.
    """
    loss = []
    transition_logits = []
    transition_gold_labels = []
    relation_logits = []
    relation_gold_labels = []

    # The batch consists of a tensor of processed sentences that are
    # ready to be used as input to the LSTM, and the corresponding
    # sentence objects that are needed to grade the performance
    # of the predictions.
    sentence_features, sentences = batch

    # Run the sentences through the encoder to get the outputs.
    # These outputs stay the same for a given sentence,
    # so we compute them once at the beginning.
    token_sequences = sentence_features[:, 0, :]

    sentence_lengths = to_int_tensor(
        data=[len(sentence) for sentence in sentences],
        device=self.model.device,
    )

    contextualized_tokens_batch = self.model.get_contextualized_input(
        token_sequences, sentence_lengths
    )

    # Create the initial configurations for all sentences in the batch
    configurations = [
        Configuration(
            sentence,
            contextualized_input,
            self.model,
            sentence_features=sentence_feature,
        )
        for contextualized_input, sentence_feature, sentence in zip(
            contextualized_tokens_batch, sentence_features, sentences
        )
    ]

    # Main loop for the sentences in this batch
    while configurations:
        # Remove all finished configurations
        configurations = [c for c in configurations if not c.is_terminal]
        if not configurations:
            break

        # Pass the stacks and buffers through the MLPs in one batch
        configurations = self.predict_logits(configurations, self.model)

        # The actual computation of the loss must be done sequentially
        for configuration in configurations:
            # Predict a list of possible actions: transitions, their
            # label (if the transition is LEFT_ARC / RIGHT_ARC) and the
            # score of the action based on the MLP output.
            actions = configuration.predict_actions()

            # Calculate the costs for each action. These determine
            # which action should be performed based on the current
            # configuration.
            costs, shift_case = configuration.get_transition_costs(actions)

            # Compute the best valid and the best wrong action, where
            # the latter is a transition that is technically possible
            # but would introduce an error compared to the gold tree.
            # To keep the model robust, we sometimes deliberately use
            # it instead of the valid one.
            best_action, best_valid_action, best_wrong_action = \
                configuration.select_actions(
                    actions, costs, error_probability, margin_threshold
                )

            # Apply the dynamic oracle to update the sentence structure
            # for the case that the chosen action does not exactly
            # follow the gold tree.
            configuration.update_dynamic_oracle(best_action, shift_case)

            # Apply the best action and update the stack and buffer
            configuration.apply_transition(best_action)

            if criterion:
                # Collect logits and gold labels for the
                # cross-entropy loss computed after the loop
                gold_transition = best_valid_action.transition.value
                gold_relation = self.model.relations.label_signature.get_id(
                    (best_valid_action.transition, best_valid_action.relation)
                )

                transition_logits.append(
                    configuration.scores["transition_logits"]
                )
                relation_logits.append(
                    configuration.scores["relation_logits"]
                )
                transition_gold_labels.append(gold_transition)
                relation_gold_labels.append(gold_relation)
            else:
                # Compute the loss from the margin between the scores
                if (
                    best_wrong_action.transition is not None
                    and best_valid_action.np_score
                    < best_wrong_action.np_score + margin_threshold
                ):
                    margin = best_wrong_action.score - best_valid_action.score
                    loss.append(margin)

    if criterion:
        transition_logits = torch.stack(transition_logits)
        relation_logits = torch.stack(relation_logits)
        transition_gold_labels = to_int_tensor(
            transition_gold_labels, self.model.device
        )
        relation_gold_labels = to_int_tensor(
            relation_gold_labels, self.model.device
        )

        transition_loss = criterion(transition_logits, transition_gold_labels)
        relation_loss = criterion(relation_logits, relation_gold_labels)
        return transition_loss + relation_loss

    return loss
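# NOTE (illustrative sketch, not part of this codebase): when no criterion
# is passed, the training signal above is a hinge-style margin. Whenever
# the best wrong action scores within margin_threshold of the best valid
# action, the difference of their scores becomes a loss term; minimizing
# it pushes the valid score up and the wrong score down. In isolation:
import torch

def _sketch_margin_loss(margin_threshold=1.0):
    valid_score = torch.tensor(2.3, requires_grad=True)  # best valid action
    wrong_score = torch.tensor(2.0, requires_grad=True)  # best wrong action

    losses = []
    if wrong_score.item() + margin_threshold > valid_score.item():
        losses.append(wrong_score - valid_score)

    if losses:
        total = torch.stack(losses).sum()
        total.backward()  # grad is -1 for valid_score, +1 for wrong_score
    return losses

# The criterion path instead stacks the collected logits and sums two
# cross-entropy terms, one over transitions and one over relations.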
def _generate_next_datapoint(self, configuration):
    if not configuration.is_terminal:
        stack = configuration.stack_tensor
        buffer = configuration.buffer_tensor

        possible_actions = list(self._get_possible_action(configuration))
        costs, shift_case = configuration.get_transition_costs(possible_actions)

        valid_actions = configuration.get_valid_actions(possible_actions, costs)
        wrong_actions = configuration.get_wrong_actions(possible_actions, costs)

        if valid_actions:
            actions = [("valid", choice(valid_actions))]

            if random() < self.error_probability and costs[T.SWAP] != 0:
                selected_wrong_actions = self._remove_label_duplicates(
                    wrong_actions
                )
                transitions = {a.transition for a in valid_actions}
                selected_wrong_actions = [
                    a
                    for a in selected_wrong_actions
                    if a.transition != T.SWAP
                    and a.transition not in transitions
                ]

                if selected_wrong_actions:
                    wrong_action = choice(selected_wrong_actions)
                    actions.append(("wrong", wrong_action))

            shuffle(actions)

            for i, (source, action) in enumerate(actions):
                if len(actions) == 1 or i == len(actions) - 1:
                    # If this is the only / last action, reuse the existing
                    # configuration to avoid the deepcopy overhead.
                    new_config = configuration
                else:
                    new_config = Configuration(
                        deepcopy(configuration.sentence),
                        None,
                        None,
                        False,
                        configuration.device,
                    )
                    new_config.buffer = deepcopy(configuration.buffer)
                    new_config.stack = deepcopy(configuration.stack)

                new_config.update_dynamic_oracle(action, shift_case)
                new_config.apply_transition(action)

                gold_transition, gold_relation = self._get_gold_labels(action)

                if source == "valid":
                    wrong_transitions_tensor, wrong_relations_tensor = \
                        self._get_all_labels(wrong_actions)

                    yield ConfigurationItem(
                        sentence=self._get_sentence_tensor(new_config.sentence),
                        stack=stack,
                        buffer=buffer,
                        gold_transition=gold_transition,
                        gold_relation=gold_relation,
                        wrong_transitions=wrong_transitions_tensor,
                        wrong_relations=wrong_relations_tensor,
                    )

                yield from self._generate_next_datapoint(new_config)
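# NOTE (illustrative sketch, not part of this codebase): the generator
# above is a randomized walk over parser states. Each call yields an item
# for the current state, then recurses into one successor -- or
# occasionally two, when a deliberately wrong action is branched off with
# probability error_probability so off-gold states are also covered.
# The control flow in miniature:
from copy import deepcopy
from random import random, shuffle

def _sketch_exploration(state, depth=0, error_probability=0.1, max_depth=3):
    if depth == max_depth:  # stands in for configuration.is_terminal
        return
    branches = ["valid"]
    if random() < error_probability:
        branches.append("wrong")  # a deliberate mistake to explore
    shuffle(branches)
    for i, source in enumerate(branches):
        # Reuse the state on the last branch, copy otherwise -- the same
        # deepcopy-avoidance trick as in _generate_next_datapoint.
        next_state = state if i == len(branches) - 1 else deepcopy(state)
        next_state.append(source)
        yield list(next_state)
        yield from _sketch_exploration(
            next_state, depth + 1, error_probability, max_depth
        )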
def __init__(self, sentence):
    # Work on a copy so that applied transitions do not
    # mutate the caller's sentence object.
    self.sentence = deepcopy(sentence)
    self.configuration = Configuration(
        sentence=self.sentence, contextualized_input=None, model=None
    )