예제 #1
0
def predict(input_text, net_trained, candidate_num=3, output_print=False):
    """Return the indices of the top ``candidate_num`` model outputs for a text.

    The text is preprocessed, tokenized, wrapped in ``[CLS]`` / ``[SEP]``,
    mapped to ids through the pickled ``TEXT`` vocabulary, padded to the
    module-level ``max_length`` and fed to ``net_trained`` as a batch of one.
    The module-level ``offset`` is subtracted from the model outputs before
    the top-k selection.

    Args:
        input_text: Raw text; passed to ``preprocessing_text``.
        net_trained: Callable model returning ``(outputs, attention_probs)``
            when called with the id tensor and the keyword arguments used
            below.
        candidate_num: Number of top-scoring entries to return.
        output_print: When truthy, print the offset-adjusted outputs.

    Returns:
        The index tensor produced by ``torch.topk(outputs, candidate_num)``.
    """
    TEXT = pickle_load(PKL_FILE)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer_bert = BertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False)

    tokens = ["[CLS]", *tokenizer_bert.tokenize(preprocessing_text(input_text))]
    # Truncate over-long inputs (keeping room for the closing [SEP]) instead
    # of overflowing the fixed-size id buffer with an IndexError.
    tokens = tokens[:max_length - 1]
    tokens.append("[SEP]")

    # Unfilled positions keep the value 1 -- presumably the vocabulary's
    # padding index; TODO confirm against TEXT.vocab.
    token_ids = torch.ones(max_length, dtype=torch.int64)
    stoi = TEXT.vocab.stoi
    token_ids[:len(tokens)] = torch.tensor(
        [stoi[token] for token in tokens], dtype=torch.int64
    )
    input_ids = token_ids.unsqueeze(0).to(device)  # batch of one: (1, max_length)

    # NOTE(review): no padding mask is supplied (attention_mask=None), which
    # matches the previous behaviour of this function -- confirm it also
    # matches how net_trained was trained before passing one.
    # Only top-k indices are returned, so gradient tracking is unnecessary.
    with torch.no_grad():
        outputs, _ = net_trained(input_ids,
                                 token_type_ids=None,
                                 attention_mask=None,
                                 output_all_encoded_layers=False,
                                 attention_show_flg=True)

        offset_tensor = torch.tensor(offset, device=device)
        outputs -= offset_tensor

    if output_print:
        print(outputs)
    _, preds = torch.topk(outputs, candidate_num)
    return preds
예제 #2
0
def tokenizer_with_preprocessing(text):
    """Preprocess ``text`` and tokenize it with a BERT tokenizer.

    A ``BertTokenizer`` is built from ``VOCAB_FILE`` with lower-casing
    disabled, the text is run through ``preprocessing_text``, and the
    tokenizer's ``tokenize`` result is returned unchanged.
    """
    bert_tokenizer = BertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False)
    return bert_tokenizer.tokenize(preprocessing_text(text))