def predict(input_text, net_trained, candidate_num=3, output_print=False):
    """Return the top-``candidate_num`` class indices predicted for ``input_text``.

    Parameters
    ----------
    input_text : str
        Raw text to classify.
    net_trained : torch.nn.Module
        Trained BERT classifier. Called with (ids, token_type_ids,
        attention_mask, output_all_encoded_layers, attention_show_flg)
        and expected to return (scores, attention_probs).
    candidate_num : int, optional
        Number of top-scoring classes to return (default 3).
    output_print : bool, optional
        If True, print the offset-adjusted score tensor.

    Returns
    -------
    torch.Tensor
        Indices of the ``candidate_num`` highest-scoring classes.
    """
    # Vocabulary object (torchtext-style Field with .vocab.stoi) saved at
    # training time -- reloaded on every call, as in the original.
    TEXT = pickle_load(PKL_FILE)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer_bert = BertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False)

    # Clean, tokenize, and wrap with BERT special tokens.
    tokens = tokenizer_bert.tokenize(preprocessing_text(input_text))
    tokens = ["[CLS]"] + tokens + ["[SEP]"]

    # Fixed-length id tensor pre-filled with 1, which appears to be the
    # padding id (the original built a mask with `!= 1`) -- TODO confirm.
    token_ids = torch.ones(max_length).to(torch.int64)

    # Map tokens to vocabulary ids; truncate to max_length so inputs longer
    # than the model window cannot raise an IndexError (original bug).
    ids = [TEXT.vocab.stoi[tok] for tok in tokens]
    for i, idx in enumerate(ids[:max_length]):
        token_ids[i] = idx

    batch = token_ids.unsqueeze(0).to(device)  # shape (1, max_length)

    # NOTE(review): the original computed a padding mask `(batch != 1)` but
    # still passed attention_mask=None, so the mask was never used. Behavior
    # is preserved here (mask removed as dead code), but passing the mask is
    # probably what was intended -- verify against training code.
    outputs, attention_probs = net_trained(
        batch,
        token_type_ids=None,
        attention_mask=None,
        output_all_encoded_layers=False,
        attention_show_flg=True,
    )

    # Subtract per-class offsets (presumably a class-bias correction; the
    # meaning of `offset` is defined elsewhere -- confirm).
    outputs -= torch.tensor(offset, device=device)

    if output_print:
        print(outputs)

    _, preds = torch.topk(outputs, candidate_num)
    return preds
def tokenizer_with_preprocessing(text):
    """Clean ``text`` and split it into BERT sub-word tokens.

    Parameters
    ----------
    text : str
        Raw input string.

    Returns
    -------
    list[str]
        Sub-word tokens produced by the BERT tokenizer.
    """
    cleaned = preprocessing_text(text)
    bert_tokenizer = BertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False)
    return bert_tokenizer.tokenize(cleaned)