Example #1
    def __init__(self):
        """
            vocab_path: Path to the vocab file.
            model_paths: Path to the model file.
            max_len: The max sentence length(all longer will be truncated). type->int
            min_len: The minimum sentence length(all longer will be returned w/o changes) type-> int
            iterations: The number of iterations of the model. type->int
            min_error_probability: Minimum probability for each action to apply.
                                   Also, minimum error probability, as described in the paper. type->float
            lowercase_tokens: The number of iterations of the model. type->int
            model_name: Name of the transformer model.you can chose['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert']
            special_tokens_fix: Whether to fix problem with [CLS], [SEP] tokens tokenization.
                                For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa. type->int
            confidence: How many probability to add to $KEEP token. type->float
            is_ensemble: Whether to do ensembling. type->int
            weigths: Used to calculate weighted average.
            batch_size: The size of hidden unit cell.
            :return:
            """

        if not os.path.exists('./model/bert_0_gector.th'):
            self.download_model()

        self.model = GecBERTModel(vocab_path='./data/output_vocabulary',
                                  model_paths=['./model/bert_0_gector.th'],
                                  max_len=50, min_len=3,
                                  iterations=5,
                                  min_error_probability=0.0,
                                  lowercase_tokens=0,
                                  model_name='bert',
                                  special_tokens_fix=0,
                                  log=False,
                                  confidence=0,
                                  is_ensemble=0,
                                  weigths=None)
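The wrapped GecBERTModel consumes pre-tokenized sentences through handle_batch, which returns the corrected token lists plus a correction count (the same call pattern appears in Examples #8 and #12). A minimal sketch against the instance built above, with illustrative tokens:

# Hedged sketch: handle_batch takes a list of token lists and
# returns (corrected token lists, number of corrections applied).
batch = [['she', 'go', 'to', 'school']]   # illustrative input
preds, cnt = model.handle_batch(batch)    # 'model' = the GecBERTModel built above
corrected = ' '.join(preds[0])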
Example #2
 def __init__(self, model_path):
     self.model = GecBERTModel(
         vocab_path='data/output_vocabulary',
         model_paths=[model_path],
         max_len=GectorBertModel.MAX_LEN,
         min_len=3,
         iterations=1,  # limit to 1 iteration to make attention analysis reasonable
         min_error_probability=0.0,
         model_name='bert',  # we're using BERT
         special_tokens_fix=0,  # disabled for BERT
         log=False,
         confidence=0,
         is_ensemble=0,
         weigths=None)
Example #3
def main(args):
    # get all paths
    # Optionally limit CPU thread usage (note: set_num_threads must be called, not assigned):
    # if args.count_thread != -1:
    #     torch.set_num_threads(args.count_thread)
    #     os.environ["OMP_NUM_THREADS"] = str(args.count_thread)
    #     os.environ["MKL_NUM_THREADS"] = str(args.count_thread)

    if args.cuda_device_index != -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device_index)
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

    model = GecBERTModel(vocab_path=args.vocab_path,
                         model_paths=args.model_path,
                         max_len=args.max_len,
                         min_len=args.min_len,
                         iterations=args.iteration_count,
                         min_error_probability=args.min_error_probability,
                         min_probability=args.min_error_probability,
                         lowercase_tokens=args.lowercase_tokens,
                         model_name=args.transformer_model,
                         special_tokens_fix=args.special_tokens_fix,
                         log=False,
                         confidence=args.additional_confidence,
                         is_ensemble=args.is_ensemble,
                         weigths=args.weights,
                         use_cpu=bool(args.use_cpu))

    cnt_corrections = predict_for_file(args.input_file,
                                       args.output_file,
                                       model,
                                       batch_size=args.batch_size,
                                       save_logs=args.save_logs)
    # evaluate with m2 or ERRANT
    print(f"Produced overall corrections: {cnt_corrections}")
Example #4
def main(args):
    # get all paths
    model = GecBERTModel(vocab_path=args.vocab_path,
                         model_paths=args.model_path,
                         max_len=args.max_len, min_len=args.min_len,
                         iterations=args.iteration_count,
                         min_error_probability=args.min_error_probability,
                         min_probability=args.min_error_probability,
                         lowercase_tokens=args.lowercase_tokens,
                         model_name=args.transformer_model,
                         special_tokens_fix=args.special_tokens_fix,
                         log=False,
                         confidence=args.additional_confidence,
                         is_ensemble=args.is_ensemble,
                         weigths=args.weights)
    # Just run the model.
    cnt_corrections, wenben1, wenben2 = predict_for_file(args.input_file, args.output_file, model,
                                                         batch_size=args.batch_size)
    '''
    Each error record has the fields:
      "explain": explanation of the correction,
      "location": position of the erroneous word,
      "sensitive": the erroneous text,
      "expect": the suggested replacement,
      "level": error severity (0-2, higher is more severe),
      "errtype": error type,
      "shortsentence": short context around the error
    '''
    error_inform = []
    wenben3 = [i.split(' ') for i in wenben1]
    for i in range(len(wenben3)):
        error = []
        tmp1 = wenben3[i]  # original tokens
        tmp2 = wenben2[i]  # corrected tokens
        for j in range(min(len(tmp1), len(tmp2))):
            if tmp1[j] != tmp2[j]:
                error.append({
                    'explain': 'gec',
                    'location': j,
                    'sensitive': tmp1[j],
                    'expect': tmp2[j],
                    'level': 1,
                    'errtype': 'gec',
                    'shortsentence': tmp1[max(j - 1, 0):j + 1]  # max() avoids a wrap-around slice at j == 0
                })
        error_inform.append(error)
    print(error_inform)

    # evaluate with m2 or ERRANT
    print(f"Produced overall corrections: {cnt_corrections}")
    print("都预测完毕")
Example #5
def load_model():
    model = GecBERTModel(
        vocab_path=vocab_path,
        model_paths=[model_path],
        min_error_probability=0.66,
        model_name='xlnet',
        max_len=50,
        min_len=3,
        iterations=5,
        min_probability=0.0,
        lowercase_tokens=0,
        special_tokens_fix=0,
        confidence=0.0,  # keep it zero on CPU
        is_ensemble=0,
        weigths=None)
    return model
Example #6
def load_model():
    model = GecBERTModel(vocab_path='gector/data/output_vocabulary',
                         model_paths=[
                             "gector/model_path/bert_0_gector.th",
                             "gector/model_path/roberta_1_gector.th",
                             "gector/model_path/xlnet_0_gector.th"
                         ],
                         max_len=50,
                         min_len=3,
                         iterations=5,
                         min_error_probability=0,
                         min_probability=0,
                         lowercase_tokens=0,
                         model_name=['bert', 'roberta', 'xlnet'],
                         special_tokens_fix=1,
                         log=False,
                         confidence=0,
                         is_ensemble=1)
    return model
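With is_ensemble=1 and no weigths argument, the three models presumably contribute equally. The weigths parameter (the misspelling is the API's own) supplies per-model weights for the weighted average; a hedged sketch with hypothetical values:

# Hypothetical: favor RoBERTa in the ensemble above. Assumed format for
# 'weigths': one float per entry in model_paths, in order.
ensemble_kwargs = dict(is_ensemble=1, weigths=[1.0, 2.0, 1.0])
model = GecBERTModel(vocab_path='gector/data/output_vocabulary',
                     model_paths=["gector/model_path/bert_0_gector.th",
                                  "gector/model_path/roberta_1_gector.th",
                                  "gector/model_path/xlnet_0_gector.th"],
                     model_name=['bert', 'roberta', 'xlnet'],
                     special_tokens_fix=1,
                     **ensemble_kwargs)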
Example #7
def main(args):
    # get all paths
    model = GecBERTModel(vocab_path=args.vocab_path,
                         model_paths=args.model_path,
                         max_len=args.max_len, min_len=args.min_len,
                         iterations=args.iteration_count,
                         min_error_probability=args.min_error_probability,
                         min_probability=args.min_error_probability,
                         lowercase_tokens=args.lowercase_tokens,
                         model_name=args.transformer_model,
                         special_tokens_fix=args.special_tokens_fix,
                         log=False,
                         confidence=args.additional_confidence,
                         is_ensemble=args.is_ensemble,
                         weigths=args.weights)

    cnt_corrections = predict_for_file(args.input_file, args.output_file, model,
                                       batch_size=args.batch_size)
    # evaluate with m2 or ERRANT
    print(f"Produced overall corrections: {cnt_corrections}")
Example #8
def predict(request: GrammarRequest, model: GecBERTModel = Depends(get_model)):
    # return model.predict(request)
    preds, _ = model.handle_batch([request.text.split()])

    if len(preds) > 0:
        result = ""
        remove_characters = ["'", '"', "(", ")", "[", "]"]  # strip quote/bracket artifacts

        for word in preds[0]:
            for character in remove_characters:
                word = word.replace(character, "")

            if not word:  # the word may be empty after stripping; word[0] would raise
                continue

            first_char = word[0]
            if first_char.isalnum():
                result += " "

            result += word

        return GrammarResponse(result=result.strip())
    else:
        return GrammarResponse(result="")
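The route decorator for this handler is not included in the excerpt (Example #10 defines get_model and GrammarRequest). Assuming it is mounted as a POST /predict route, a client call could look like:

# Hypothetical client call; route path and port are assumptions.
import requests

resp = requests.post("http://localhost:8000/predict",
                     json={"text": "she go to school"})
print(resp.json())  # e.g. {"result": "She goes to school"} (illustrative)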
Example #9
import numpy
from typing import Any, List

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types

from gector.gec_model import GecBERTModel  # import path as in Example #11


class GectorBertModel(lit_model.Model):
    ATTENTION_LAYERS = 12
    ATTENTION_HEADS = 12
    MAX_LEN = 50

    def __init__(self, model_path):
        self.model = GecBERTModel(
            vocab_path='data/output_vocabulary',
            model_paths=[model_path],
            max_len=GectorBertModel.MAX_LEN,
            min_len=3,
            iterations=1,  # limit to 1 iteration to make attention analysis reasonable
            min_error_probability=0.0,
            model_name='bert',  # we're using BERT
            special_tokens_fix=0,  # disabled for BERT
            log=False,
            confidence=0,
            is_ensemble=0,
            weigths=None)

    # LIT API implementation
    def max_minibatch_size(self, config: Any = None) -> int:
        return 32

    def predict_minibatch(self, inputs: List, config=None) -> List:
        # we append '$START' to the beginning of token lists because GECTOR does as well (and this is
        # what BERT ends up processing; see preprocess() and postprocess_batch() in gec_model)

        # this breaks down if we have duplicates, but we shouldn't
        sentence_indices = [(ex["input_text"], index)
                            for index, ex in enumerate(inputs)]
        tokenized_input_with_indices = [
            (original.split(), index) for original, index in sentence_indices
        ]
        batch = [(tokens, index)
                 for tokens, index in tokenized_input_with_indices
                 if len(tokens) >= self.model.min_len]

        # anything under min length doesn't get processed anyway, so we don't pass it in and just keep it
        # so we can put stuff back in order later
        ignored = [(tokens, index)
                   for tokens, index in tokenized_input_with_indices
                   if len(tokens) < self.model.min_len]

        model_input = [tokens for tokens, index in batch]
        predictions, _, attention = self.model.handle_batch(model_input)
        attention = attention[0]  # we only have one iteration

        assert len(predictions) == len(attention)
        output = [{
            'predicted': ' '.join(tokenlist)
        } for tokenlist in predictions]

        # wanted to average across heads and layers, but attention with different head counts breaks LIT
        #attention_averaged = numpy.average(attention, (1, 2))[:, numpy.newaxis, ...]

        batch_iter = iter(batch)
        for output_dict, attention_info in zip(output, attention):
            original_tokens, original_index = next(batch_iter)
            output_dict['original_index'] = original_index
            output_dict['layer_average'] = numpy.average(attention_info,
                                                         axis=0)

            for layer_index, attention_layer_info in enumerate(attention_info):
                output_dict['layer{}'.format(
                    layer_index)] = attention_layer_info

        output.extend({
            'predicted': ' '.join(tokens),
            'original_index': original_index
        } for tokens, original_index in ignored)
        output.sort(key=lambda x: x['original_index'])
        for tokenized_input, index in tokenized_input_with_indices:
            output[index]['input_tokens'] = ['$START'] + tokenized_input

        for d in output:
            del d['original_index']

        return output

    def input_spec(self) -> lit_types.Spec:
        return {
            "input_text": lit_types.TextSegment(),
            "target_text": lit_types.TextSegment()
        }

    def output_spec(self) -> lit_types.Spec:
        output = {
            "input_tokens":
            lit_types.Tokens(parent="input_text"),
            "predicted":
            lit_types.GeneratedText(parent='target_text'),
            'layer_average':
            lit_types.AttentionHeads(align=('input_tokens', 'input_tokens'))
        }
        for layer in range(self.ATTENTION_LAYERS):
            output['layer{}'.format(layer)] = lit_types.AttentionHeads(
                align=('input_tokens', 'input_tokens'))

        return output
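To inspect the attention outputs interactively, the wrapper can be served through LIT; a sketch assuming the usual lit_nlp dev-server pattern and a hypothetical dataset exposing the declared fields:

# Hypothetical LIT wiring; 'my_dataset' must be a lit_nlp Dataset whose spec
# provides the 'input_text'/'target_text' fields declared in input_spec().
from lit_nlp import dev_server

models = {'gector': GectorBertModel('model/bert_0_gector.th')}  # path is an assumption
datasets = {'examples': my_dataset}
dev_server.Server(models, datasets).serve()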
Example #10
MIN_LEN = 3
ITERATION_COUNT = 5
MIN_ERROR_PROBABILITY = 0
ADDITIONAL_CONFIDENCE = 0
IS_ENSEMBLE = 0
LOWERCASE_TOKENS = 0
WEIGHTS = None

model = GecBERTModel(
    vocab_path=VOCAB_PATH,
    model_paths=[MODEL_PATHS],
    max_len=MAX_LEN,
    min_len=MIN_LEN,
    iterations=ITERATION_COUNT,
    min_error_probability=MIN_ERROR_PROBABILITY,
    lowercase_tokens=LOWERCASE_TOKENS,
    model_name=TRANSFORMER_MODEL,
    special_tokens_fix=SPECIAL_TOKENS_FIX,
    log=False,
    confidence=ADDITIONAL_CONFIDENCE,
    is_ensemble=IS_ENSEMBLE,
    weigths=WEIGHTS,
)


def get_model():
    return model


class GrammarRequest(BaseModel):
    text: str
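Example #8's handler also returns a GrammarResponse, which this excerpt never defines; a minimal assumed counterpart consistent with GrammarResponse(result=...) would be:

# Assumed response model, mirroring GrammarResponse(result=...) in Example #8.
class GrammarResponse(BaseModel):
    result: str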
Example #11
import argparse
from utils.helpers import read_lines
from gector.gec_model import GecBERTModel

# -------------------------------------------------------------------- #
# ---------------------- INITIALIZE MODEL ----------------------------- #

model = GecBERTModel(vocab_path='vocab/output_vocabulary',
                     model_paths=[
                         'model/roberta_1_gector.th',
                     ],
                     max_len=50,
                     min_len=3,
                     iterations=5,
                     min_error_probability=0.0,
                     lowercase_tokens=0,
                     model_name='roberta',
                     special_tokens_fix=1,
                     log=False,
                     confidence=0,
                     is_ensemble=0,
                     weigths=None)

# -------------------------------------------------------------------- #


def predict_for_file(input_file, output_file, model, batch_size=32):
    test_data = read_lines(input_file)
    predictions = []
    cnt_corrections = 0
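The excerpt is cut off here. Based on the batching loop in Example #12 (without the LanguageTool pass), the body presumably continues along these lines:

    # Presumed continuation, following the batching pattern of Example #12.
    batch = []
    for sent in test_data:
        batch.append(sent.split())
        if len(batch) == batch_size:
            preds, cnt = model.handle_batch(batch)
            predictions.extend(preds)
            cnt_corrections += cnt
            batch = []
    if batch:
        preds, cnt = model.handle_batch(batch)
        predictions.extend(preds)
        cnt_corrections += cnt

    with open(output_file, 'w') as f:
        f.write("\n".join([" ".join(x) for x in predictions]) + '\n')
    return cnt_corrections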
Example #12
import os
import sys

import requests

from utils.helpers import read_lines
from gector.gec_model import GecBERTModel


class GrammarCorrection:
    def __init__(self):
        """
            vocab_path: Path to the vocab file.
            model_paths: Path to the model file.
            max_len: The max sentence length(all longer will be truncated). type->int
            min_len: The minimum sentence length(all longer will be returned w/o changes) type-> int
            iterations: The number of iterations of the model. type->int
            min_error_probability: Minimum probability for each action to apply.
                                   Also, minimum error probability, as described in the paper. type->float
            lowercase_tokens: The number of iterations of the model. type->int
            model_name: Name of the transformer model.you can chose['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert']
            special_tokens_fix: Whether to fix problem with [CLS], [SEP] tokens tokenization.
                                For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa. type->int
            confidence: How many probability to add to $KEEP token. type->float
            is_ensemble: Whether to do ensembling. type->int
            weigths: Used to calculate weighted average.
            batch_size: The size of hidden unit cell.
            :return:
            """

        if not os.path.exists('./model/bert_0_gector.th'):
            self.download_model()

        self.model = GecBERTModel(vocab_path='./data/output_vocabulary',
                                  model_paths=['./model/bert_0_gector.th'],
                                  max_len=50, min_len=3,
                                  iterations=5,
                                  min_error_probability=0.0,
                                  lowercase_tokens=0,
                                  model_name='bert',
                                  special_tokens_fix=0,
                                  log=False,
                                  confidence=0,
                                  is_ensemble=0,
                                  weigths=None)

    @staticmethod
    def download_model():
        link = "http://imdreamer.oss-cn-hangzhou.aliyuncs.com/bert_0_gector.th"
        os.makedirs('./model', exist_ok=True)  # make sure the target directory exists before writing
        with open('./model/bert_0_gector.th', "wb") as f:
            print('Downloading grammatical error correction model [bert_0_gector]! Please wait a minute!')
            response = requests.get(link, stream=True)
            total_length = response.headers.get('content-length')

            if total_length is None:  # no content length header
                f.write(response.content)
            else:
                dl = 0
                total_length = int(total_length)
                for data in response.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    done = int(50 * dl / total_length)
                    sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                    sys.stdout.flush()

    @staticmethod
    def language_checker(text):
        # 'tool' and 'lc' are module-level grammar-checker helpers (defined outside this excerpt)
        matches = tool.check(text)
        correct_text = lc.correct(text, matches)
        return correct_text

    def correct_sentence_by_file(self, input_file='./data/predict_for_file/input.txt',
                                 output_file='./data/predict_for_file/output.txt', batch_size=32):
        test_data = read_lines(input_file)
        predictions = []
        cnt_corrections = 0
        batch = []
        for sent in test_data:
            batch.append(self.language_checker(sent).split())
            if len(batch) == batch_size:
                preds, cnt = self.model.handle_batch(batch)
                predictions.extend(preds)
                cnt_corrections += cnt
                batch = []
        if batch:
            preds, cnt = self.model.handle_batch(batch)
            predictions.extend(preds)
            cnt_corrections += cnt

        with open(output_file, 'w') as f:
            f.write("\n".join([" ".join(x) for x in predictions]) + '\n')
        return cnt_corrections

    def correct_sentence(self, input_string):
        predictions = []
        cnt_corrections = 0
        batch = [self.language_checker(input_string).split()]
        if batch:
            preds, cnt = self.model.handle_batch(batch)
            predictions.extend(preds)
            cnt_corrections += cnt
        return [" ".join(x) for x in predictions][0]
Example #13
import logging
import pprint
from copy import deepcopy

import errant
import spacy

from utils.helpers import add_sents_idx, add_tokens_idx, token_level_edits, forward_merge_corrections, backward_merge_corrections
from gector.gec_model import GecBERTModel

logging.basicConfig(
    format='%(levelname)s: [%(asctime)s][%(filename)s:%(lineno)d] %(message)s',
    level=logging.INFO)

nlp = spacy.load("en")
annotator = errant.load(lang='en', nlp=nlp)

model = GecBERTModel(
    vocab_path="./data/output_vocabulary",
    model_paths=["./pretrain/roberta_1_gector.th"],
    # model_paths = ["./pretrain/bert_0_gector.th", "./pretrain/roberta_1_gector.th", "./pretrain/xlnet_0_gector.th"],
    model_name="roberta",
    is_ensemble=False,
    iterations=3,
)

DEFAULT_CONFIG = {
    'iterations': 3,
    'min_probability': 0.5,
    'min_error_probability': 0.7,
    'case_sensitive': True,
    'languagetool_post_process': True,
    'languagetool_call_thres': 0.7,
    'whitelist': [],
    'with_debug_info': True
}
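The code that consumes DEFAULT_CONFIG is not part of this excerpt; presumably callers overlay per-request options onto the defaults, e.g.:

# Hypothetical per-request override of the defaults above.
config = {**DEFAULT_CONFIG, 'iterations': 1, 'with_debug_info': False}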
Example #14
def get_style_transfer_model():
    return GecBERTModel(vocab_path='vocabulary',
                        model_paths=['roberta_1_gector.th'])
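Every other constructor argument falls back to its GecBERTModel default here. A hedged usage sketch, with illustrative tokens:

# Hedged sketch: correct one pre-tokenized sentence with the model above.
model = get_style_transfer_model()
preds, cnt = model.handle_batch([['their', 'going', 'home']])  # illustrative input
print(' '.join(preds[0]))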