# Training script for the network defined in the local `UNet` module.
#
# Visible steps: build the model on the GPU, create an Adam optimizer and a
# BCE loss, optionally resume from a saved state_dict, make sure the image
# output directory exists, then iterate over the dataset running the forward
# pass.
#
# NOTE(review): in the code visible here the loop stops after the forward
# pass -- `loss_func` and `optimizer` are created but never used. The rest of
# the training step (loss, backward, optimizer step, checkpoint saving) is
# presumably elsewhere or was cut off; confirm against the full file.
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

import MKDataset
import UNet

# Dataset root handed to MKDataset.MKDataset.
path = r'D:\data\VOCtest_06-Nov-2007\VOCdevkit\VOC2007'
# Checkpoint file holding the network's state_dict.
module = r'module.pkl'
# Directory created below; nothing in the visible code writes into it yet.
img_save_path = r'D:\train_img'
# NOTE(review): `batch` is not read anywhere in the visible code; the
# DataLoader below hard-codes batch_size=4.
batch = 1

net = UNet.MainNet().cuda()
optimizer = torch.optim.Adam(net.parameters())
loss_func = nn.BCELoss()

dataloader = DataLoader(MKDataset.MKDataset(path), batch_size=4, shuffle=True)

# Resume from an existing checkpoint when one is present.
if os.path.exists(module):
    net.load_state_dict(torch.load(module))
    print('module is loaded !')

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race and also creates missing parent directories.
os.makedirs(img_save_path, exist_ok=True)

for i, (xs, ys) in enumerate(dataloader):
    # Move the batch to the GPU, where the network lives.
    xs = xs.cuda()
    ys = ys.cuda()
    xs_ = net(xs)
# Inference script: runs the network from the local `UNet` module over every
# file in `pic_test` and writes each output image to `predict/` under the
# same file name.
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

import UNet

# PIL image -> tensor conversion applied to every input image.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
# Tensor -> PIL image conversion for the network output; built once instead
# of once per image.
to_pil = transforms.ToPILImage()

module = 'model/module03.pkl'
path = "pic_test"
output_dir = "predict"
img_list = os.listdir(path)

net = UNet.MainNet(16).cuda()
if os.path.exists(module):
    net.load_state_dict(torch.load(module))
else:
    # Previously a missing checkpoint was silent, so predictions came from
    # randomly initialised weights with no hint as to why they looked wrong.
    print("warning: checkpoint %r not found, using untrained weights" % module)
# Switch to evaluation mode regardless of whether weights were loaded.
net.eval()

# The output directory must exist before Image.save is called.
os.makedirs(output_dir, exist_ok=True)

# no_grad: inference does not need an autograd graph, so none is recorded and
# per-image GPU memory is released as soon as the names are rebound.
with torch.no_grad():
    for img_name in img_list:
        with Image.open(os.path.join(path, img_name)) as src:
            # Image.LANCZOS is the named form of the resampling filter the
            # original passed as the bare integer 1.
            img = src.resize((512, 512), Image.LANCZOS)

        # Add a leading batch dimension for the network, drop it afterwards.
        data = transform(img).unsqueeze(0).cuda()
        out_img = net(data).squeeze(0)

        img_save = to_pil(out_img.cpu())
        img_save.save(os.path.join(output_dir, img_name))
        print("%s保持完毕" % img_name)