def _init_problems(self, problems_path):
        """Load the predefined essay problems and precompute their embeddings.

        Reads a JSON file containing a list of essays, each carrying a
        ``hints`` list. Every hint is normalised into a dict with the keys
        ``theme``, ``comment``, ``problem`` and ``author_position``. The
        normalised hints (with the first character of every value
        lower-cased) are stored in ``self._predefined_problems``; the
        embeddings of the texts returned by ``self._get_hint_text`` for them
        are stored in ``self._problem_embeddings``.

        Args:
            problems_path: Path to the JSON file with the essays.

        Raises:
            IndexError: If a cleaned ``problem`` string contains no space,
                i.e. nothing is left after dropping its first word.
        """
        # JSON is UTF-8 by specification; do not rely on the platform's
        # locale-dependent default encoding.
        with open(problems_path, encoding='utf-8') as f:
            essays = json.load(f)

        stripped_chars = punctuation + ' '

        self._predefined_problems, problem_embedded_texts = [], []
        for essay in essays:
            for raw_hint in essay['hints']:
                problem = fix_spaces(raw_hint['problem']).strip(stripped_chars)
                # Keep everything after the first word of the problem text.
                problem = problem.split(' ', 1)[1]

                comment = fix_spaces(raw_hint['comment']).strip(stripped_chars)
                # Comments shorter than five characters are discarded.
                if len(comment) < 5:
                    comment = ''

                hint = {
                    'theme': raw_hint['theme'].strip(),
                    'comment': comment,
                    'problem': problem,
                    'author_position': fix_spaces(raw_hint['author_position']).strip(stripped_chars) + '.',
                }

                # The embedded text is built from the hint *before* its values
                # get their first character lower-cased below.
                problem_embedded_texts.append(self._get_hint_text(hint))
                self._predefined_problems.append({
                    key: value[0].lower() + value[1:] if value else ''
                    for key, value in hint.items()
                })

        self._problem_embeddings = self._session.run(
            self._embedded_text, feed_dict={self._text_input: problem_embedded_texts}
        )
        # Sanity check: one embedding per stored problem.
        assert len(self._predefined_problems) == len(problem_embedded_texts) == len(self._problem_embeddings)
Exemple #2
0
    def predict_from_model(self, task):
        """Predict, for every left-hand condition of ``task``, the id of a choice.

        The choices returned by ``self.parse_task`` are tokenized and batched,
        the model's ``error_type_logits`` are restricted to the columns of the
        task's conditions, and for each such column the index with the highest
        logit along axis 0 selects the choice (presumably one row per choice —
        this depends on ``self._batch_collector`` and the model).

        Returns:
            A dict mapping the ``id`` of each entry in
            ``task['question']['left']`` to the ``id`` of the selected choice.
        """
        choices, conditions = self.parse_task(task)

        def as_example(choice):
            # Inference-time example: the error type is unknown.
            cleaned = fix_spaces(choice['text'])
            encoded = self._tokenizer.encode_plus(cleaned, add_special_tokens=True)
            return Example(
                text=cleaned,
                token_ids=encoded['input_ids'],
                segment_ids=encoded['token_type_ids'],
                mask=[1] * len(encoded['input_ids']),
                error_type=None,
                error_type_id=None
            )

        examples = [as_example(choice) for choice in choices]
        model_inputs = self._batch_collector(examples)

        # Columns of the logits that correspond to the conditions of this task.
        target_condition_indices = [_ERRORS.index(condition) for condition in conditions]

        with torch.no_grad():
            logits = self._model(model_inputs)['error_type_logits']
            best_choice = logits[:, target_condition_indices].argmax(0)

            return {
                condition['id']: choices[best_choice[position]]['id']
                for position, condition in enumerate(task['question']['left'])
            }
Exemple #3
0
    def _get_examples_from_task(self, task):
        """Yield one labelled ``Example`` per choice of a solved task.

        The solution (``correct_variants[0]`` when present, otherwise
        ``correct``) maps options to 1-based choice numbers. Options are
        walked in sorted order and paired positionally with the conditions
        returned by ``self.parse_task``. Choices that no option points at are
        labelled with the last entry of ``_ERRORS``.
        """
        choices, conditions = self.parse_task(task)

        solution = task['solution']
        answers = solution['correct_variants'][0] if 'correct_variants' in solution else solution['correct']

        # 0-based choice index -> error type named by the solution.
        labelled = {}
        for option_index, option in enumerate(sorted(answers)):
            labelled[int(answers[option]) - 1] = conditions[option_index]

        # Every choice gets a label; unreferenced ones fall back to _ERRORS[-1].
        error_types = [labelled.get(index, _ERRORS[-1]) for index, _ in enumerate(choices)]

        # Each answer must have landed on a distinct, existing choice.
        assert len(answers) == sum(1 for label in error_types if label != _ERRORS[-1])

        for choice, error_type in zip(choices, error_types):
            text = fix_spaces(choice['text'])

            encoded = self._tokenizer.encode_plus(text, add_special_tokens=True)
            token_ids = encoded['input_ids']
            segment_ids = encoded['token_type_ids']
            assert len(token_ids) == len(segment_ids)

            yield Example(
                text=text,
                token_ids=token_ids,
                segment_ids=segment_ids,
                mask=[1] * len(token_ids),
                error_type=error_type,
                error_type_id=_ERRORS.index(error_type)
            )
Exemple #4
0
    def build(cls, text, tokenizer, text_pair=None):
        """Create an instance of ``cls`` from raw text and an optional text pair.

        Both texts are passed through ``fix_spaces`` (the pair only when it is
        truthy) and encoded together with ``tokenizer.encode_plus`` with
        special tokens added. The mask is all ones, one entry per token id.
        """
        text = fix_spaces(text)
        # A falsy pair (None or empty string) is passed through untouched.
        text_pair = fix_spaces(text_pair) if text_pair else text_pair

        encoded = tokenizer.encode_plus(text, text_pair=text_pair, add_special_tokens=True)
        token_ids = encoded['input_ids']
        segment_ids = encoded['token_type_ids']
        assert len(token_ids) == len(segment_ids)

        return cls(
            text=text,
            token_ids=token_ids,
            segment_ids=segment_ids,
            mask=[1] * len(token_ids),
            text_pair=text_pair,
        )
    def _init_arguments(self, arguments_path):
        """Load the predefined arguments and precompute their embeddings.

        Reads a JSON file containing a list of arguments, each with a ``name``
        and a list of ``examples``. Every example is cleaned and terminated
        with a period, then stored in ``self._predefined_arguments``. The text
        that gets embedded is the capitalised argument name (as a sentence)
        followed by the example; the resulting embeddings are stored in
        ``self._argument_embeddings``.

        Args:
            arguments_path: Path to the JSON file with the arguments.
        """
        # JSON is UTF-8 by specification; do not rely on the platform's
        # locale-dependent default encoding.
        with open(arguments_path, encoding='utf-8') as f:
            arguments = json.load(f)

        stripped_chars = punctuation + ' '

        self._predefined_arguments, argument_embedded_texts = [], []
        for argument in arguments:
            examples = argument['examples']
            if not examples:
                # Nothing to store; the name is never needed in this case.
                continue

            # The name sentence is identical for all examples of an argument,
            # so format it once instead of once per example.
            name_sentence = argument['name'].strip(stripped_chars) + '.'
            name_sentence = name_sentence[0].upper() + name_sentence[1:]

            for example in examples:
                example = fix_spaces(example).strip(stripped_chars) + '.'

                self._predefined_arguments.append(example)
                argument_embedded_texts.append(name_sentence + ' ' + example)

        self._argument_embeddings = self._session.run(
            self._embedded_text, feed_dict={self._text_input: argument_embedded_texts}
        )

        # Sanity check: one embedding per stored argument example.
        assert len(self._predefined_arguments) == len(argument_embedded_texts) == len(self._argument_embeddings)
Exemple #6
0
 def _get_sentences(text):
     for sent in text.split('.'):
         if re.search('\(\d+\)', sent):
             yield fix_spaces(sent.strip()) + '.'
 def _get_sentences(text):
     for sent in text.split("."):
         if re.search("\(\d+\)", sent):
             yield fix_spaces(sent.strip()) + "."