-
Notifications
You must be signed in to change notification settings - Fork 0
/
fakenews.py
39 lines (26 loc) · 1.05 KB
/
fakenews.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn.functional as F
from transformers import BertTokenizer
from preprocess import preprocess_text
def load_specific_model():
    """Load the fine-tuned BERT classifier and its tokenizer from disk.

    Returns:
        (device, model, tokenizer): the torch device in use (CUDA when
        available, otherwise CPU), the model switched to eval mode, and
        the BertTokenizer restored from the saved-model directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("./saved_model/")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    model = torch.load("model_after_train.pt", map_location=device)
    model.eval()
    return device, model, tokenizer
def detect_fake(text, device, model, tokenizer):
    """Classify *text* as real (class 0) or fake (class 1) news.

    The text is split into model-sized parts by ``preprocess_text``; logits
    from each part are summed, then softmaxed into a single probability pair.

    Args:
        text: raw article text to classify.
        device: torch device the model and tensors live on.
        model: the fine-tuned classifier returned by ``load_specific_model``.
        tokenizer: the matching tokenizer.

    Returns:
        (is_real, confidence_percent): ``is_real`` is True when class 0 wins;
        ``confidence_percent`` is the winning softmax probability * 100.
    """
    text_parts = preprocess_text(text, device, tokenizer)
    overall_output = torch.zeros((1, 2)).to(device)
    try:
        # Inference only: without no_grad() every model call accumulates an
        # autograd graph through `overall_output`, which is exactly what
        # exhausts GPU memory on long articles.
        with torch.no_grad():
            for part in text_parts:
                if len(part) > 0:
                    overall_output += model(part.reshape(1, -1))[0]
    except RuntimeError as err:
        # Best-effort: keep the logits accumulated so far, but report the
        # real error instead of assuming it was an out-of-memory condition.
        print("RuntimeError during inference, using partial result: {}".format(err))
    overall_output = F.softmax(overall_output[0], dim=-1)
    value, result = overall_output.max(0)
    term = result.item() == 0  # class index 0 == "real"
    print("Is real - {} at {}%".format(term, value.item() * 100))
    return term, value.item() * 100