コード例 #1
0
ファイル: test.py プロジェクト: dsnaaa/pytorch_mnist
def test(model_path='D:/bd/model.tar'):
    """Evaluate the saved CNN on the test set and print its accuracy.

    Relies on ``CNN``, ``device``, ``test_loader`` and ``test_dataset`` being
    defined at module level elsewhere in this file.

    Args:
        model_path: Path of the saved ``state_dict`` to load. Defaults to the
            path that was previously hard-coded here.
    """
    with torch.no_grad():
        net = CNN(1, 10)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # load on this machine; the model must then live on the same device
        # as the batches fed to it below.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.to(device)
        # Inference mode for dropout / batch-norm layers, if CNN has any.
        net.eval()
        eval_acc = 0
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            out = net(img)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()
        print('Acc: {:.2f}%'.format(eval_acc / len(test_dataset) * 100))
コード例 #2
0
    def predict():
        """Run the saved CNN over the selected images and write a report.

        The input folder comes from ``path.get()`` and the output folder from
        ``path2.get()``. These are presumably GUI entry variables from the
        enclosing scope, which is not visible here. Images are processed in
        consecutive groups of three, with flags named left/mid/right for
        positions 0/1/2 of each group.

        A group in which any image is predicted as class 0 is logged as
        '不合格' (fail), followed by the 1-based positions that were flagged.
        Otherwise the group is logged as '合格' (pass). The log is written to
        ``<output folder>/result.txt`` and the output folder is then opened
        with ``explorer``.
        """
        def _now_str():
            # Single place for the timestamp format used in console and log.
            return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        started = _now_str()
        log = '程序开始时间' + started + '\n'
        print('程序开始时间', started)
        print('载入模型')
        cnn = CNN()
        if constant.USE_GPU:
            cnn = cnn.cuda()
        # Without a GPU, remap the checkpoint to the CPU so a GPU-saved model
        # still loads; None keeps torch.load's default placement.
        cnn.load_state_dict(torch.load(
            constant.LAST_MODEL_PATH,
            map_location=None if constant.USE_GPU else 'cpu'))
        # Inference mode for dropout / batch-norm layers, if CNN has any.
        cnn.eval()
        test_x, test_y, images_name = predict_data_preprocessing([path.get()])
        # NOTE(review): result.txt is written into this folder below, so
        # delete_file_folder must clear its contents but keep the folder
        # itself. Confirm.
        delete_file_folder(path2.get())
        detect_started = _now_str()
        # Append (the original overwrote `log` here, losing the start time).
        log += '开始检测时间' + detect_started + '\n'
        print('开始检测时间', detect_started)
        # flags[k] becomes 1 when image k (0=left, 1=mid, 2=right) of the
        # current group of three is predicted as class 0.
        # NOTE(review): if len(test_x) is not a multiple of 3, the trailing
        # partial group is never reported.
        flags = [0, 0, 0]
        with torch.no_grad():  # inference only: no autograd graph needed
            for i, x in enumerate(test_x):
                # Add two leading singleton dimensions before the forward pass.
                x = x.unsqueeze(0).unsqueeze(0)
                if constant.USE_GPU:
                    x = x.cuda()
                # cnn(...) presumably returns a tuple whose first item is the
                # class scores. Confirm against CNN.forward.
                test_output = cnn(x)[0]
                pre_y = torch.max(test_output, 1)[1].data.cpu().squeeze().numpy()
                if pre_y == 0:
                    flags[i % 3] = 1
                if i % 3 == 2:
                    left, mid, right = flags
                    if any(flags):
                        # [0:-4] presumably strips a 3-letter file extension.
                        draw_image(left, right, mid, path.get() + images_name[i][0:-4])
                        positions = [str(k + 1) for k, hit in enumerate(flags) if hit]
                        log += '不合格 ' + ','.join(positions) + '\n'
                    else:
                        log += '合格\n'
                    flags = [0, 0, 0]
        finished = _now_str()
        print('结束时间', finished)
        log += '结束时间' + finished + '\n'
        # Explicit UTF-8: the report contains Chinese text, which the locale's
        # default encoding may not be able to represent.
        with open(path2.get() + '/result.txt', 'w', encoding='utf-8') as f:
            f.write(log)
        subprocess.Popen('explorer "%s"' % path2.get().replace('/', '\\'))
コード例 #3
0
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import numpy as np
from PIL import Image
from train import CNN

# Classify a single image with the trained CNN and print the predicted digit.
net = CNN(1, 10)
# This script runs the model on the CPU, so remap a checkpoint that may have
# been saved from a GPU run; otherwise torch.load fails on CPU-only machines.
net.load_state_dict(torch.load('D:/bd/model.tar', map_location='cpu'))
net.eval()  # inference mode for dropout / batch-norm layers, if CNN has any
input_image = 'D:/4.png'

im = Image.open(input_image).resize((28, 28))  # load the image, resize to 28x28
im = im.convert('L')  # convert to single-channel grayscale
im_data = np.array(im)

# NOTE(review): pixel values stay in the 0-255 range here. If training used
# transforms.ToTensor() (which scales to [0, 1]) and/or Normalize, apply the
# same preprocessing here. Confirm against train.py.
im_data = torch.from_numpy(im_data).float()

# Shape (1, 1, 28, 28); presumably (batch, channel, H, W) as CNN expects.
im_data = im_data.view(1, 1, 28, 28)
with torch.no_grad():  # inference only: no autograd graph needed
    out = net(im_data)
_, pred = torch.max(out, 1)

# .item() prints the plain class index instead of the tensor repr.
print('预测的数字是:{}。'.format(pred.item()))