示例#1
0
文件: main.py 项目: ivanmkc/gector
def predict(request: GrammarRequest, model: GecBERTModel = Depends(get_model)):
    """Run grammatical error correction on ``request.text``.

    The text is whitespace-tokenized and sent to the model as a one-sentence
    batch. Each predicted token then has quote and bracket characters
    (' " ( ) [ ]) removed. Tokens that become empty are dropped. The rest are
    rejoined with a leading space only before tokens whose first character is
    alphanumeric, so punctuation attaches to the preceding word.

    Args:
        request: incoming request; only its ``text`` attribute is read.
        model: the GECToR model, injected via ``Depends(get_model)``.

    Returns:
        ``GrammarResponse`` whose ``result`` is the corrected text, or ``""``
        when the model returns no predictions.
    """
    preds, _ = model.handle_batch([request.text.split()])
    if not preds:
        return GrammarResponse(result="")

    # Built once per call: deletes quote and bracket characters from a token.
    strip_table = str.maketrans("", "", "'\"()[]")

    pieces = []
    for word in preds[0]:
        word = word.translate(strip_table)
        if not word:
            # The token consisted only of stripped characters. Skip it rather
            # than indexing word[0], which would raise IndexError.
            continue
        # Alphanumeric-initial tokens start a new word; anything else
        # (punctuation) is glued to the previous token.
        if word[0].isalnum():
            pieces.append(" ")
        pieces.append(word)

    return GrammarResponse(result="".join(pieces).strip())
示例#2
0
class GectorBertModel(lit_model.Model):
    """LIT model wrapper around a single BERT-based GECToR ``GecBERTModel``.

    Runs one correction iteration per example. It exposes the corrected text
    together with the model's attention so LIT can visualise it.
    """

    # Number of ``layer{i}`` attention fields declared in output_spec().
    # Assumed to match the loaded BERT checkpoint -- TODO confirm.
    ATTENTION_LAYERS = 12
    ATTENTION_HEADS = 12
    # Sentence length limit forwarded to GecBERTModel(max_len=...).
    MAX_LEN = 50

    def __init__(self, model_path):
        """Load the GECToR checkpoint at ``model_path`` with BERT settings."""
        self.model = GecBERTModel(
            vocab_path='data/output_vocabulary',
            model_paths=[model_path],
            max_len=GectorBertModel.MAX_LEN,
            min_len=3,
            # limit to 1 iteration to make attention analysis reasonable
            iterations=1,
            min_error_probability=0.0,
            model_name='bert',  # we're using BERT
            special_tokens_fix=0,  # disabled for BERT
            log=False,
            confidence=0,
            is_ensemble=0,
            weigths=None)  # (sic) keyword spelling used by GecBERTModel

    # LIT API implementation
    def max_minibatch_size(self, config: Any = None) -> int:
        """Return the largest number of examples LIT may pass per call."""
        return 32

    def predict_minibatch(self, inputs: List, config=None) -> List:
        """Correct a minibatch of examples and attach attention to each result.

        '$START' is prepended to every token list because GECToR does the same,
        and that is the sequence BERT ends up processing (see preprocess() and
        postprocess_batch() in gec_model).

        Args:
            inputs: LIT examples; only ``ex["input_text"]`` is read.
            config: unused; accepted for LIT API compatibility.

        Returns:
            One dict per input, in input order. Every dict has ``predicted``
            and ``input_tokens``. Inputs with at least ``self.model.min_len``
            whitespace tokens additionally get ``layer_average`` and one
            ``layer{i}`` entry per attention layer.

        Raises:
            RuntimeError: if the model returns a different number of
                predictions or attention entries than sentences it was given.
        """
        tokenized = [ex["input_text"].split() for ex in inputs]

        # Inputs shorter than min_len are not processed by the model anyway.
        # They are passed through unchanged and never sent to handle_batch.
        min_len = self.model.min_len
        processed = [i for i, tokens in enumerate(tokenized)
                     if len(tokens) >= min_len]

        outputs = [{'predicted': ' '.join(tokens)} for tokens in tokenized]

        # Skip the model call entirely when nothing qualifies: an empty batch
        # may yield no attention at all, making attention[0] fail.
        if processed:
            model_input = [tokenized[i] for i in processed]
            predictions, _, attention = self.model.handle_batch(model_input)
            attention = attention[0]  # we only have one iteration

            if not len(predictions) == len(attention) == len(model_input):
                raise RuntimeError(
                    'model returned {} predictions and {} attention entries '
                    'for {} sentences'.format(
                        len(predictions), len(attention), len(model_input)))

            # wanted to average across heads and layers, but attention with
            # different head counts breaks LIT, so only layers are averaged
            for index, tokenlist, attention_info in zip(
                    processed, predictions, attention):
                out = outputs[index]
                out['predicted'] = ' '.join(tokenlist)
                out['layer_average'] = numpy.average(attention_info, axis=0)
                for layer_index, attention_layer_info in enumerate(
                        attention_info):
                    out['layer{}'.format(layer_index)] = attention_layer_info

        for out, tokens in zip(outputs, tokenized):
            out['input_tokens'] = ['$START'] + tokens

        return outputs

    def input_spec(self) -> lit_types.Spec:
        """Declare the example fields this model reads."""
        return {
            "input_text": lit_types.TextSegment(),
            "target_text": lit_types.TextSegment()
        }

    def output_spec(self) -> lit_types.Spec:
        """Declare the fields returned by predict_minibatch().

        Includes the input tokens, the generated text, the layer-averaged
        attention, and one AttentionHeads field for each of the
        ATTENTION_LAYERS layers.
        """
        output = {
            "input_tokens":
            lit_types.Tokens(parent="input_text"),
            "predicted":
            lit_types.GeneratedText(parent='target_text'),
            'layer_average':
            lit_types.AttentionHeads(align=('input_tokens', 'input_tokens'))
        }
        for layer in range(self.ATTENTION_LAYERS):
            output['layer{}'.format(layer)] = lit_types.AttentionHeads(
                align=('input_tokens', 'input_tokens'))

        return output
示例#3
0
class GrammarCorrection:
    """Grammatical error correction using a BERT-based GECToR model.

    Each sentence is first passed through ``language_checker``, which uses
    the module-level ``tool`` and ``lc`` objects. The result is then passed
    through ``GecBERTModel.handle_batch``.
    """

    # Local path of the model checkpoint; downloaded on first use if missing.
    MODEL_PATH = './model/bert_0_gector.th'
    # Download source of the checkpoint (plain HTTP, no integrity check).
    MODEL_URL = "http://imdreamer.oss-cn-hangzhou.aliyuncs.com/bert_0_gector.th"
    # Output vocabulary path passed to GecBERTModel.
    VOCAB_PATH = './data/output_vocabulary'

    def __init__(self):
        """Load the model, downloading the checkpoint first if it is missing.

        GecBERTModel is configured with:
            vocab_path: VOCAB_PATH, the output vocabulary.
            model_paths: [MODEL_PATH], the checkpoint(s) to load.
            max_len=50: maximum sentence length; longer sentences are
                truncated.
            min_len=3: minimum sentence length; shorter sentences are
                returned without changes.
            iterations=5: number of correction iterations of the model.
            min_error_probability=0.0: minimum probability for an action to
                be applied (the minimum error probability from the paper).
            lowercase_tokens=0: presumably a flag for lowercasing input
                tokens -- TODO confirm against GecBERTModel.
            model_name='bert': transformer backbone; one of 'bert', 'gpt2',
                'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'.
            special_tokens_fix=0: whether to fix [CLS]/[SEP] tokenization;
                to reproduce reported results use 0 for BERT/XLNet and 1 for
                RoBERTa.
            log=False: logging disabled.
            confidence=0: extra probability added to the $KEEP token.
            is_ensemble=0: whether to do ensembling.
            weigths=None: weights for the ensemble's weighted average
                (keyword spelling as used by GecBERTModel).
        """
        if not os.path.exists(self.MODEL_PATH):
            self.download_model()

        self.model = GecBERTModel(vocab_path=self.VOCAB_PATH,
                                  model_paths=[self.MODEL_PATH],
                                  max_len=50, min_len=3,
                                  iterations=5,
                                  min_error_probability=0.0,
                                  lowercase_tokens=0,
                                  model_name='bert',
                                  special_tokens_fix=0,
                                  log=False,
                                  confidence=0,
                                  is_ensemble=0,
                                  weigths=None)

    @staticmethod
    def download_model():
        """Download the checkpoint from MODEL_URL to MODEL_PATH.

        Data is streamed into ``MODEL_PATH + '.part'`` and renamed to
        MODEL_PATH only after the download completes. A failed or interrupted
        download therefore never leaves a truncated checkpoint behind that
        ``__init__`` would later mistake for a complete one. A text progress
        bar is drawn on stdout when the server sends a Content-Length.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        target = GrammarCorrection.MODEL_PATH
        partial = target + '.part'
        os.makedirs(os.path.dirname(target) or '.', exist_ok=True)

        print('Downloading grammatical error correction model [bert_0_gector]! Please wait a minute!')
        try:
            with requests.get(GrammarCorrection.MODEL_URL, stream=True) as response:
                # Without this check an HTTP error page would be saved as the model.
                response.raise_for_status()
                total_length = int(response.headers.get('content-length') or 0)
                with open(partial, "wb") as f:
                    if total_length <= 0:  # no usable content length header
                        f.write(response.content)
                    else:
                        dl = 0
                        for data in response.iter_content(chunk_size=4096):
                            dl += len(data)
                            f.write(data)
                            # Cap at 50 in case more bytes arrive than announced.
                            done = min(50, int(50 * dl / total_length))
                            sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                            sys.stdout.flush()
                        sys.stdout.write("\n")
        except BaseException:
            # Remove the incomplete file so the next attempt starts clean; re-raise.
            if os.path.exists(partial):
                os.remove(partial)
            raise
        os.replace(partial, target)

    @staticmethod
    def language_checker(text):
        """Return ``text`` with the fixes from the module-level ``tool`` applied.

        ``tool.check`` finds the matches and ``lc.correct`` applies them.
        """
        matches = tool.check(text)
        correct_text = lc.correct(text, matches)
        return correct_text

    def correct_sentence_by_file(self, input_file='./data/predict_for_file/input.txt',
                                 output_file='./data/predict_for_file/output.txt', batch_size=32):
        """Correct every line of ``input_file`` and write the results to ``output_file``.

        Args:
            input_file: file read with ``read_lines``; each item is one sentence.
            output_file: destination; one corrected, space-joined sentence
                per line, written as UTF-8.
            batch_size: number of sentences sent to the model per
                ``handle_batch`` call.

        Returns:
            Total number of corrections, summed over the counts returned by
            ``handle_batch``.
        """
        test_data = read_lines(input_file)
        predictions = []
        cnt_corrections = 0
        batch = []
        for sent in test_data:
            batch.append(self.language_checker(sent).split())
            if len(batch) == batch_size:
                preds, cnt = self.model.handle_batch(batch)
                predictions.extend(preds)
                cnt_corrections += cnt
                batch = []
        # Flush the final, partially filled batch.
        if batch:
            preds, cnt = self.model.handle_batch(batch)
            predictions.extend(preds)
            cnt_corrections += cnt

        # Explicit encoding so non-ASCII text doesn't depend on the platform locale.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(" ".join(x) for x in predictions) + '\n')
        return cnt_corrections

    def correct_sentence(self, input_string):
        """Correct a single sentence and return it as a space-joined string.

        Raises:
            IndexError: if the model returns no prediction for the sentence.
        """
        batch = [self.language_checker(input_string).split()]
        preds, _ = self.model.handle_batch(batch)
        return " ".join(preds[0])