Example #1
def get_preprocessor(ptr_config_info, model_config):
    with open(ptr_config_info.vocab, mode='rb') as io:
        vocab = pickle.load(io)

    # the padding function is identical in both branches, so build it once
    pad_sequence = PadSequence(length=model_config.length, pad_val=vocab.to_indices(vocab.padding_token))

    if model_config.type == 'etri':
        ptr_tokenizer = ETRITokenizer.from_pretrained(ptr_config_info.tokenizer, do_lower_case=False)
        preprocessor = PreProcessor(vocab=vocab, split_fn=ptr_tokenizer.tokenize, pad_fn=pad_sequence)
    elif model_config.type == 'skt':
        ptr_tokenizer = SentencepieceTokenizer(ptr_config_info.tokenizer)
        preprocessor = PreProcessor(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=pad_sequence)
    else:
        # guard against returning an unbound `preprocessor` for an unknown type
        raise ValueError('unsupported model type: {}'.format(model_config.type))
    return preprocessor
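For orientation, a hypothetical call site for get_preprocessor; the Config class and the file names are assumptions borrowed from Example #2, and the preprocess call mirrors the transform usage shown there:

# Hypothetical usage; Config and the json file names below are assumptions.
ptr_config_info = Config('pretrained/config_skt.json')       # assumed to expose .vocab and .tokenizer
model_config = Config('experiments/base_model/config.json')  # assumed to expose .type and .length
preprocessor = get_preprocessor(ptr_config_info, model_config)
indices, token_types = preprocessor.preprocess('첫 번째 문장', '두 번째 문장')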
Example #2
def predict(sentence1, sentence2):
    ptr_dir = "C:/Users/aaaaa/workspace/fact-check/BERT_pairwise_text_classification/pretrained"
    data_dir = "C:/Users/aaaaa/workspace/fact-check/BERT_pairwise_text_classification/data"
    caseType = "skt"
    model_dir = "C:/Users/aaaaa/workspace/fact-check/BERT_pairwise_text_classification/experiments/base_model"
    checkpoint_model_file = "best_skt.tar"
    
    # ptr_dir = "BERT_pairwise_text_classification/pretrained"
    # data_dir = "BERT_pairwise_text_classification/data"
    # caseType = "skt"
    # model_dir = "BERT_pairwise_text_classification/experiments/base_model"
    # checkpoint_model_file = "best_skt.tar"
    
    # ptr_dir = "pretrained"
    # data_dir = "data"
    # caseType = "skt"
    # model_dir = "experiments/base_model"
    # checkpoint_model_file = "best_skt.tar"
    
    ptr_dir = Path(ptr_dir)
    data_dir = Path(data_dir)
    model_dir = Path(model_dir)
    checkpoint_model_file = Path(checkpoint_model_file)
    
    ptr_config = Config(ptr_dir / 'config_{}.json'.format(caseType))  # use caseType instead of hardcoding 'skt'
    data_config = Config(data_dir / 'config.json')
    model_config = Config(model_dir / 'config.json')
    
    # vocab
    with open(os.path.join(ptr_dir, ptr_config.vocab), mode='rb') as io:
        vocab = pickle.load(io)

    # tokenizer
    ptr_tokenizer = SentencepieceTokenizer(os.path.join(ptr_dir, ptr_config.tokenizer))
    pad_sequence = PadSequence(length=model_config.length, pad_val=vocab.to_indices(vocab.padding_token))
    preprocessor = PreProcessor(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=pad_sequence)
    
    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint(checkpoint_model_file)
    config = BertConfig(os.path.join(ptr_dir, ptr_config.config))
    model = PairwiseClassifier(config, num_classes=model_config.num_classes, vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    device = torch.device('cpu')
    model.to(device)
    
    transform = preprocessor.preprocess
    model.eval()  # eval() is a no-op if the model is already in inference mode
        
    indices, token_types = [torch.tensor([elm]) for elm in transform(sentence1, sentence2)]

    with torch.no_grad():
        label = model(indices, token_types)
    label = label.max(dim=1)[1]
    label = label.numpy()[0]

    return label
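A quick sanity-check call for predict; the sentence pair is a placeholder, and the return value is the integer class index computed from label.numpy()[0] above:

# Placeholder inputs; predict returns an integer class index.
label = predict('대한민국의 수도는 서울이다.', '서울은 대한민국의 수도이다.')
print(label)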
Example #3
def get_tokenizer(dataset_config, model_config):
    with open(dataset_config.vocab, mode="rb") as io:
        vocab = pickle.load(io)
    pad_sequence = PadSequence(
        length=model_config.length, pad_val=vocab.to_indices(vocab.padding_token)
    )
    tokenizer = Tokenizer(vocab=vocab, split_fn=split_morphs, pad_fn=pad_sequence)
    return tokenizer
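A hypothetical call for get_tokenizer; the two config objects and the split_and_transform method name are assumptions about the surrounding repo, not shown in the snippet:

# Hypothetical usage; dataset_config/model_config are assumed Config-like objects
# exposing the .vocab and .length attributes read above.
tokenizer = get_tokenizer(dataset_config, model_config)
token_ids = tokenizer.split_and_transform('아버지가 방에 들어가신다')  # assumed Tokenizer method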
Example #4
def prepare_data(self):
    pad_sequence = PadSequence(length=self.hparams.length,
                               pad_val=self.vocab.to_indices(
                                   self.vocab.padding_token))
    tokenizer = Tokenizer(vocab=self.vocab,
                          split_fn=split_to_jamo,
                          pad_fn=pad_sequence)
    self.tokenizer = tokenizer
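prepare_data reads like a PyTorch Lightning hook (note self.hparams); if so, it is invoked by the Trainer rather than called directly. A hypothetical sketch:

# Hypothetical: assumes `import pytorch_lightning as pl` and that `model` is a
# LightningModule defining prepare_data as above.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)  # the Trainer calls model.prepare_data() before training starts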
Example #5
    model_dir = Path(args.model_dir)

    ptr_config = Config(ptr_dir / 'config_{}.json'.format(args.type))
    data_config = Config(data_dir / 'config.json')
    model_config = Config(model_dir / 'config.json')

    # vocab
    with open(ptr_config.vocab, mode='rb') as io:
        vocab = pickle.load(io)

    # tokenizer
    if args.type == 'etri':
        ptr_tokenizer = ETRITokenizer.from_pretrained(ptr_config.tokenizer,
                                                      do_lower_case=False)
        pad_sequence = PadSequence(length=model_config.length,
                                   pad_val=vocab.to_indices(
                                       vocab.padding_token))
        preprocessor = PreProcessor(vocab=vocab,
                                    split_fn=ptr_tokenizer.tokenize,
                                    pad_fn=pad_sequence)
    elif args.type == 'skt':
        ptr_tokenizer = SentencepieceTokenizer(ptr_config.tokenizer)
        pad_sequence = PadSequence(length=model_config.length,
                                   pad_val=vocab.to_indices(
                                       vocab.padding_token))
        preprocessor = PreProcessor(vocab=vocab,
                                    split_fn=ptr_tokenizer,
                                    pad_fn=pad_sequence)
    else:
        # guard against an unbound `preprocessor` for an unknown type
        raise ValueError('unsupported model type: {}'.format(args.type))

    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
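    # The excerpt stops here; what follows is a hypothetical continuation,
    # mirroring the restore sequence in Example #2 (the checkpoint file name
    # pattern is an assumption):
    checkpoint = checkpoint_manager.load_checkpoint('best_{}.tar'.format(args.type))
    config = BertConfig(ptr_config.config)
    model = PairwiseClassifier(config, num_classes=model_config.num_classes, vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()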
Example #6
from model.split import split_morphs
from model.utils import Tokenizer, PadSequence

app = Flask(__name__)
app.config.from_pyfile("config.py")
app.database = create_engine(app.config["DB_URL"],
                             encoding="utf-8",
                             max_overflow=0)

# preprocessor & model
num_classes = app.config["MODEL"]["num_classes"]
max_length = app.config["MODEL"]["length"]

with open("model/checkpoint/vocab.pkl", mode="rb") as io:
    vocab = pickle.load(io)
pad_sequence = PadSequence(length=max_length,
                           pad_val=vocab.to_indices(vocab.padding_token))
tokenizer = Tokenizer(vocab=vocab, split_fn=split_morphs, pad_fn=pad_sequence)

model = SenCNN(num_classes=num_classes, vocab=vocab)
ckpt = torch.load("model/checkpoint/best.tar",
                  map_location=torch.device("cpu"))
model.load_state_dict(ckpt["model_state_dict"])
model.eval()


@app.route("/alive_check", methods=["GET"])
def alive_check():
    return "alive", 200


@app.route("/inference", methods=["POST"])