예제 #1
0
    def _run_prediction_batch(self, batch):
        """Greedily parse every sentence of ``batch`` with the current model.

        Args:
            batch: A ``(sentence_features, sentences)`` pair. ``sentence_features``
                is squeezed along dim 1 before it is fed to the encoder
                (presumably shaped ``(batch, 1, ...)`` -- TODO confirm).

        Returns:
            A ``(pred_sentences, gold_sentences)`` tuple of parallel lists: the
            i-th prediction belongs to the i-th gold sentence. Entries are added
            in the order in which parses finish, not in batch order.

        Raises:
            RuntimeError: if a non-terminal configuration has no applicable
                action left after SWAP filtering.
        """
        pred_sentences = []
        gold_sentences = []

        def collect_finished(pending):
            # Record every terminal configuration (prediction paired with its
            # gold sentence) and return only those that still need transitions.
            still_running = []
            for configuration in pending:
                if configuration.is_terminal:
                    pred_sentences.append(configuration.predicted_sentence)
                    gold_sentences.append(configuration.sentence)
                else:
                    still_running.append(configuration)
            return still_running

        sentence_features, sentences = batch
        sentence_features = torch.squeeze(sentence_features, dim=1)

        sentence_lengths = torch.tensor(
            [len(sentence) for sentence in sentences],
            dtype=torch.int64,
            device=self.device,
        )

        # The encoder output depends only on the sentence, so compute it once.
        contextualized_tokens_batch = self.model.compute_lstm_output(
            sentence_features, sentence_lengths
        )

        configurations = [
            Configuration(
                sentence,
                contextualized_input,
                self.model,
                sentence_features=sentence_feature,
            )
            for contextualized_input, sentence, sentence_feature in zip(
                contextualized_tokens_batch, sentences, sentence_features
            )
        ]

        # Record any configuration that is already terminal before the first
        # transition, so it is never asked to predict an action.
        configurations = collect_finished(configurations)

        while configurations:
            # Pass the stacks and buffers through the MLPs in one batch
            configurations, _, _ = self._update_classification_scores(configurations)

            # Choosing and applying the transitions must be done sequentially
            for configuration in configurations:
                # Predict a list of possible actions: Transitions, their
                # label (if the transition is LEFT/ RIGHT_ARC) and the
                # score of the action based on the MLP output.
                actions = configuration.predict_actions()

                if not configuration.swap_possible:
                    # Exclude swap options
                    actions = [action for action in actions if action.transition != T.SWAP]

                # Explicit check instead of `assert`, which is stripped under -O.
                if not actions:
                    raise RuntimeError(
                        "No applicable action for a non-terminal configuration"
                    )
                best_action = Configuration.get_best_action(actions)

                if best_action.transition == T.SWAP:
                    configuration.num_swap += 1

                configuration.apply_transition(best_action)

            # Move finished parses to the results and keep the rest going.
            configurations = collect_finished(configurations)

        return pred_sentences, gold_sentences
예제 #2
0
    def generate_configuration(self, sentence: Sentence) -> ConfigurationItem:
        """Yield the distinct training items derived from ``sentence``.

        A fresh ``Configuration`` (no contextualized input, no model) is built
        for the sentence and expanded with ``self._generate_next_datapoint``.
        Items are de-duplicated on the contents of their ``stack`` and
        ``buffer`` tensors. Only the first item seen for a given
        (stack, buffer) pair is yielded.

        NOTE(review): this is a generator function, so the
        ``ConfigurationItem`` return annotation is inaccurate;
        ``Iterator[ConfigurationItem]`` would describe it correctly.
        """
        root_configuration = Configuration(
            sentence,
            contextualized_input=None,
            model=None,
            device=self.device,
        )

        seen_states = set()
        for item in self._generate_next_datapoint(root_configuration):
            state_key = (
                tuple(item.stack.cpu().tolist()),
                tuple(item.buffer.cpu().tolist()),
            )
            if state_key in seen_states:
                continue
            seen_states.add(state_key)
            yield item
예제 #3
0
    def _process_training_batch(
        self,
        batch: Tuple[torch.Tensor, List[Sentence]],
        error_probability: float,
        margin_threshold: float,
        criterion: Optional[_Loss] = None,
    ) -> List[torch.Tensor]:
        """
        Parse the sentences in the given batch and return the training loss.

        Every sentence is parsed to completion. At each step the model's action
        scores and the transition costs are passed to
        ``configuration.select_actions``, which picks the action to apply. That
        action may be a wrong one, depending on ``error_probability`` and
        ``margin_threshold``. The dynamic oracle is then updated and the action
        is applied.

        Args:
            batch: ``(sentence_features, sentences)``. ``sentence_features[:, 0, :]``
                is the encoder input. The per-sentence slices of the full
                tensor are handed to each ``Configuration``.
            error_probability: forwarded to ``select_actions``.
            margin_threshold: forwarded to ``select_actions``. In margin mode it
                is also the gap by which the best valid action's score must
                beat the best wrong action's score.
            criterion: optional loss module applied to the collected transition
                logits and relation logits.

        Returns:
            With ``criterion``: a single tensor, the transition loss plus the
            relation loss.
            Without ``criterion``: a list of margin tensors
            (``best_wrong.score - best_valid.score``), one for each step where
            the valid action did not win by ``margin_threshold``. The list may
            be empty.
            NOTE(review): the ``List[torch.Tensor]`` annotation only describes
            the margin mode.
        """
        loss = []

        transition_logits = []
        transition_gold_labels = []

        relation_logits = []
        relation_gold_labels = []

        # The batch consists of a tensor of processed sentences
        # that are ready to be used as an input to the LSTM
        # and the corresponding sentence objects that are needed
        # to grade the performance of the predictions.
        sentence_features, sentences = batch

        # Run the sentence through the encoder to get the outputs.
        # These outputs will stay the same for the sentence,
        # so we compute them once in the beginning.
        token_sequences = sentence_features[:, 0, :]

        sentence_lengths = to_int_tensor(
            data=[len(sentence) for sentence in sentences],
            device=self.model.device)

        contextualized_tokens_batch = self.model.get_contextualized_input(
            token_sequences, sentence_lengths)

        # Create the initial configurations for all sentences in the batch
        configurations = [
            Configuration(
                sentence,
                contextualized_input,
                self.model,
                sentence_features=sentence_feature,
            ) for contextualized_input, sentence_feature, sentence in zip(
                contextualized_tokens_batch, sentence_features, sentences)
        ]

        # Main loop for the sentences in this batch
        while configurations:
            # Remove all finished configurations (this also drops any that
            # were terminal from the start before they are ever scored)
            configurations = [c for c in configurations if not c.is_terminal]
            if not configurations:
                break

            # Pass the stacks and buffers through the MLPs in one batch
            configurations = self.predict_logits(configurations, self.model)

            # The actual computation of the loss must be done sequentially
            for configuration in configurations:
                # Predict a list of possible actions: Transitions, their
                # label (if the transition is LEFT/ RIGHT_ARC) and the
                # score of the action based on the MLP output.
                actions = configuration.predict_actions()

                # Calculate the 'costs' for each action. These determine
                # which action should be performed based on the given
                # configuration.
                costs, shift_case = configuration.get_transition_costs(actions)

                # Compute the best valid and the best wrong action,
                # where the latter one is a transition that is technically
                # possible, but would introduce an error compared to the
                # gold tree. To keep the model robust, we sometimes
                # decide, however, to use it instead of the valid one.
                best_action, best_valid_action, best_wrong_action = configuration.select_actions(
                    actions, costs, error_probability, margin_threshold)

                # Apply the dynamic oracle to update the sentence structure
                # for the case that the chosen action does not exactly
                # follow the gold tree.
                configuration.update_dynamic_oracle(best_action, shift_case)

                # Apply the best action and update the stack and buffer
                configuration.apply_transition(best_action)

                if criterion:
                    # Compute CrossEntropy loss
                    # The training target is always the best *valid* action,
                    # even when a wrong action was applied above.
                    gold_transition = best_valid_action.transition.value
                    gold_relation = self.model.relations.label_signature.get_id(
                        (best_valid_action.transition,
                         best_valid_action.relation))

                    transition_logits.append(
                        configuration.scores["transition_logits"])
                    relation_logits.append(
                        configuration.scores["relation_logits"])

                    transition_gold_labels.append(gold_transition)
                    relation_gold_labels.append(gold_relation)
                else:
                    # Compute the loss by using the margin between the scores
                    # (only when a wrong action exists and the valid action
                    # does not beat it by at least margin_threshold)
                    if (best_wrong_action.transition is not None
                            and best_valid_action.np_score <
                            best_wrong_action.np_score + margin_threshold):
                        margin = best_wrong_action.score - best_valid_action.score
                        loss.append(margin)

        if criterion:
            # NOTE(review): if no step was recorded (e.g. every configuration
            # was terminal from the start), torch.stack([]) raises here.
            transition_logits = torch.stack(transition_logits)
            relation_logits = torch.stack(relation_logits)

            transition_gold_labels = to_int_tensor(transition_gold_labels,
                                                   self.model.device)
            relation_gold_labels = to_int_tensor(relation_gold_labels,
                                                 self.model.device)

            transition_loss = criterion(transition_logits,
                                        transition_gold_labels)
            relation_loss = criterion(relation_logits, relation_gold_labels)

            return transition_loss + relation_loss
        return loss
예제 #4
0
    def _generate_next_datapoint(self, configuration):
        if not configuration.is_terminal:
            stack = configuration.stack_tensor
            buffer = configuration.buffer_tensor

            possible_actions = list(self._get_possible_action(configuration))
            costs, shift_case = configuration.get_transition_costs(
                possible_actions)

            valid_actions = configuration.get_valid_actions(
                possible_actions, costs)
            wrong_actions = configuration.get_wrong_actions(
                possible_actions, costs)

            if valid_actions:
                actions = [("valid", choice(valid_actions))]

                if random() < self.error_probability and costs[T.SWAP] != 0:
                    selected_wrong_actions = self._remove_label_duplicates(
                        wrong_actions)

                    transitions = set([a.transition for a in valid_actions])
                    selected_wrong_actions = [
                        a for a in selected_wrong_actions
                        if a.transition != T.SWAP
                        and a.transition not in transitions
                    ]

                    if selected_wrong_actions:
                        wrong_action = choice(selected_wrong_actions)
                        actions.append(("wrong", wrong_action))

                shuffle(actions)
                for i, (source, action) in enumerate(actions):
                    if len(actions) == 1 or i == len(actions) - 1:
                        # If this the only / last action, reuse the existing
                        # configuration to avoid the deepcopy overhead.
                        new_config = configuration
                    else:
                        new_config = Configuration(
                            deepcopy(configuration.sentence),
                            None,
                            None,
                            False,
                            configuration.device,
                        )
                        new_config.buffer = deepcopy(configuration.buffer)
                        new_config.stack = deepcopy(configuration.stack)

                    new_config.update_dynamic_oracle(action, shift_case)
                    new_config.apply_transition(action)

                    gold_transition, gold_relation = self._get_gold_labels(
                        action)

                    if source == "valid":
                        wrong_transitions_tensor, wrong_relations_tensor = self._get_all_labels(
                            wrong_actions)

                        yield ConfigurationItem(
                            sentence=self._get_sentence_tensor(
                                new_config.sentence),
                            stack=stack,
                            buffer=buffer,
                            gold_transition=gold_transition,
                            gold_relation=gold_relation,
                            wrong_transitions=wrong_transitions_tensor,
                            wrong_relations=wrong_relations_tensor,
                        )

                    for configuration_item in self._generate_next_datapoint(
                            new_config):
                        yield configuration_item
예제 #5
0
 def __init__(self, sentence):
     """Keep a private deep copy of ``sentence`` and a parser configuration for it.

     The configuration is built over the copy (not the caller's object), with
     no contextualized input and no model attached.
     """
     sentence_copy = deepcopy(sentence)
     self.sentence = sentence_copy
     self.configuration = Configuration(
         sentence=sentence_copy,
         contextualized_input=None,
         model=None,
     )