import time

import torch
import torch.utils.data as Data
from sklearn.metrics import f1_score, precision_score, recall_score
from torchvision import transforms
from torchvision.datasets import ImageFolder

import Network
# NOTE(review): `use_cuda` is not read in the visible part of this file;
# presumably later code moves the model/data to the GPU when it is True.
use_cuda = True

# Checkpoint to evaluate. Other checkpoints previously used here:
#   '../model/finetune.pth', '../model/k-fold-finetune.pth',
#   '../model/k-fold-finetune-DA.pth'
#       (presumably paired with Network.Net() -- TODO confirm)
#   '../model/finetune-alexnet.pth', '../model/k-fold-finetune-alex-DA.pth'
#       (presumably paired with Network.AlexNet() -- TODO confirm)
# The network class constructed below must match the checkpoint's architecture,
# otherwise load_state_dict will reject the parameter names/shapes.
MODEL_PATH = '../model/k-fold-finetune-alex.pth'

model = Network.AlexNet()

print('load model parameters')

# The model is constructed on the CPU, so map the checkpoint's tensors to the
# CPU too: this avoids a needless GPU copy and lets a checkpoint that was saved
# from a GPU run load on a machine without CUDA.
model_dict = torch.load(MODEL_PATH, map_location='cpu')
model.load_state_dict(model_dict)
# Load the test images.
# Preprocessing pipeline, applied to every image in order:
#   1. resize to 224x224;
#   2. ToTensor: convert the image to a tensor with values scaled to [0, 1];
#   3. Normalize with per-channel mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5,
#      which maps [0, 1] onto [-1, 1].
normalize = transforms.Normalize(mean=[.5] * 3, std=[.5] * 3)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)

# ImageFolder derives each sample's class label from its sub-directory name.
data = ImageFolder('../birds/testing', transform=transform)

# No batch_size is given, so DataLoader yields one sample per batch.
test_loader = Data.DataLoader(
    dataset=data,
    shuffle=True,
)