Пример #1
0
def inference(args, dataloader):
    """Run segmentation inference and save one colour-coded image per sample.

    Args:
        args: Object exposing ``model`` and ``model_path``. ``model`` selects
            the architecture (``'fcn32s'`` or ``'fcn8s'``, case-insensitive;
            any other value falls back to UNet). ``model_path`` is the
            directory containing the matching ``best_*.pth`` weight file.
        dataloader: Iterable yielding ``(images, path)`` batches; ``path[s]``
            is the file the prediction for sample ``s`` is written to. Its
            ``dataset`` length is used only for the progress display.

    The model output is treated as ``(batch, classes, height, width)``; the
    arg-max over the class axis is mapped to colours via ``class_map``.
    """
    model_name = str(args.model).lower()
    if model_name == 'fcn32s':
        model = VGG16_FCN32s(n_classes=7)
        weights = f'{args.model_path}/best_fcn32s.pth'
    elif model_name == 'fcn8s':
        model = VGG16_FCN8s(n_classes=7)
        weights = f'{args.model_path}/best_fcn8s.pth'
    else:
        model = UNet(n_channels=3, n_classes=7)
        weights = f'{args.model_path}/best_unet.pth'
    # Weights are deserialised on the CPU first, then the model is moved.
    model.load_state_dict(torch.load(weights, map_location='cpu'))
    model.eval()
    model.cuda()

    total = len(dataloader.dataset)
    done = 0
    # Inference only: do not record autograd state for the forward passes.
    with torch.no_grad():
        for images, path in dataloader:
            b = images.size(0)

            logits = model(images.cuda())
            # Softmax is monotonic, so the arg-max over the raw class scores
            # picks the same label without the extra pass.
            predict = torch.argmax(logits, dim=1).cpu().numpy()

            for s in range(b):
                labels = predict[s]
                # Size the canvas from the prediction instead of assuming
                # a fixed 512x512 output.
                pred_img = np.zeros(labels.shape + (3,), dtype=np.uint8)
                # Index-based lookup works whether class_map is a sequence
                # or a mapping keyed by class id.
                for c in range(len(class_map)):
                    pred_img[labels == c] = class_map[c]
                Image.fromarray(pred_img).save(path[s])

            # Count real samples so a smaller final batch is reported correctly.
            done += b
            print(f'\t[{done}/{total}]', end='  \r')
Пример #2
0
class Session:
    """Inference session: a ``UNet(3, 3)`` on the GPU plus its I/O directories."""

    def __init__(self):
        """Create the output/model directories and build the network on the GPU."""
        self.show_dir = "../showdir_dark_train"
        self.model_dir = "../models_dark"
        # Ensure the directories this session actually reads from and writes
        # to, rather than the (possibly different) ones in ``settings``.
        ensure_dir(self.show_dir)
        ensure_dir(self.model_dir)
        logger.info('set show dir as %s', self.show_dir)
        logger.info('set model dir as %s', self.model_dir)

        self.net = UNet(3, 3).cuda()
        self.dataset = None
        self.dataloader = None

    def get_dataloader(self, dataset_name):
        """Build, store and return a sequential single-sample loader.

        Args:
            dataset_name: Passed straight to ``ShowDataset``.

        Returns:
            The ``DataLoader`` (batch size 1, no shuffling, one worker).
        """
        self.dataset = ShowDataset(dataset_name)
        self.dataloader = DataLoader(
            self.dataset, batch_size=1, shuffle=False, num_workers=1)
        return self.dataloader

    def load_checkpoints(self, name):
        """Load the ``'net'`` weights from ``<model_dir>/<name>`` into the network.

        A missing checkpoint file is logged and otherwise ignored, leaving
        the network's current weights in place.
        """
        ckp_path = os.path.join(self.model_dir, name)
        try:
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!', ckp_path)
            return
        logger.info('Load checkpoint %s', ckp_path)
        self.net.load_state_dict(obj['net'])

    def inf_batch(self, name, batch):
        """Run the network on ``batch['O']`` without tracking gradients.

        Args:
            name: Unused; kept so existing callers keep working.
            batch: Mapping whose ``'O'`` entry is the input tensor.

        Returns:
            The network output for the batch.
        """
        # The legacy ``Variable`` wrapper is a no-op on tensors; no_grad()
        # already guarantees that nothing is recorded for autograd.
        inputs = batch['O'].cuda()
        with torch.no_grad():
            derain = self.net(inputs)

        return derain

    def save_image(self, No, imgs):
        """Write the given image tensors to ``<show_dir>/<No>.png``.

        Each tensor is scaled by 255, clipped to ``[0, 255]`` and transposed
        with ``(1, 2, 0)`` (assumes channel-first input — TODO confirm)
        before being handed to ``cv2.imwrite``.

        NOTE(review): every element of ``imgs`` is written to the same file,
        so only the last one survives; confirm this is intended.
        """
        img_file = os.path.join(self.show_dir, '%s.png' % (No))
        for img in imgs:
            pixels = (img.detach().cpu() * 255).numpy()
            pixels = np.clip(pixels, 0, 255)
            pixels = np.transpose(pixels, (1, 2, 0))
            cv2.imwrite(img_file, pixels)
import glob
import numpy as np
import torch
import os
import cv2
from unet.unet_model import UNet

if __name__ == "__main__":
    # Select the device: use CUDA when it is available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input images, one output class.
    net = UNet(n_channels=1, n_classes=1)
    # Move the network to the selected device.
    net.to(device=device)
    # Load the trained parameters, mapped onto the selected device.
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Switch to evaluation mode.
    net.eval()
    # Collect the paths of all test images.
    tests_path = glob.glob('E:/AI_data/ISBI/data/test/*.png')
    # Process every image in turn.
    for test_path in tests_path:
        # Path the result will be saved to.
        # NOTE(review): split('.')[0] keeps only the text before the first
        # dot anywhere in the path, so a dotted directory or file name would
        # be truncated; os.path.splitext would be safer.
        save_res_path = test_path.split('.')[0] + '_res.png'
        # Read the image.
        img = cv2.imread(test_path)
        # Convert to grayscale.
        # NOTE(review): cv2.imread is documented to return BGR data while the
        # conversion code here is RGB2GRAY; harmless for images whose
        # channels are identical, but confirm it matches training.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # Reshape to (batch=1, channels=1, height, width); the original
        # author's comment assumes 512x512 inputs.
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # Convert to a tensor.
Пример #4
0
import torch as th
import os
import cv2
import numpy as np
from unet.unet_model import UNet
import time

# Build the 3-channel-in / 1-channel-out UNet on the GPU and load its weights.
unet = UNet(3, 1).to('cuda')
unet.eval()
# Raw strings keep the Windows backslashes literal without relying on
# invalid escape sequences such as '\c' or '\P'; the values are unchanged.
unet.load_state_dict(th.load(r'.\checkpoint\PersonMasker262.pt'))

evalImagePath = r'E:\Person_detection\Dataset\DataSets2017\u_net\image'
evalMaskPath = r'E:\Person_detection\Pytorch-UNet\eval\mask_coco'
# Sort the listing so the index-based output names are reproducible
# (os.listdir returns entries in arbitrary order).
imgs = [os.path.join(evalImagePath, i)
        for i in sorted(os.listdir(evalImagePath))]

# Inference only: no autograd bookkeeping for the forward passes.
with th.no_grad():
    for idx, img_i in enumerate(imgs):
        frame = cv2.imread(img_i)
        if frame is None:
            # cv2.imread signals an unreadable / non-image file with None.
            continue

        # HWC image -> 1 x C x H x W float tensor on the GPU.
        img = np.expand_dims(np.transpose(frame, [2, 0, 1]), 0)
        mask = unet(th.as_tensor(img, dtype=th.float32, device='cuda'))

        # Replicate the single-channel prediction to three channels and
        # resize it to the source image so the overlay below always lines
        # up (cv2.resize takes the target size as (width, height)).
        mask = np.repeat(mask.cpu().numpy()[0, :, :, :], 3, 0)
        mask = cv2.resize(np.transpose(mask, [1, 2, 0]),
                          (frame.shape[1], frame.shape[0]))

        # Paint every pixel whose score exceeds 0.5 with a fixed colour.
        background = np.zeros_like(mask)
        color = np.ones_like(mask)
        color[:, :, 0] = 150
        color[:, :, 1] = 50
        color[:, :, 2] = 170
        mask = np.where(mask > 0.5, color, background)

        img = np.transpose(img[0, :, :, :], [1, 2, 0])
        # Saturate explicitly instead of relying on the encoder's implicit
        # float -> 8-bit conversion.
        mask_img = np.clip(mask + img, 0, 255).astype(np.uint8)
        cv2.imwrite(os.path.join(evalMaskPath, '{}.jpg'.format(idx)), mask_img)
Пример #5
0
    # Parse the run configuration and echo it for the log.
    args = get_args()
    print(args)

    # Report whether a CUDA device is available.
    have_gpu = torch.cuda.is_available()
    print('Have GPU?:{}'.format(have_gpu))

    # TensorBoard writer; args.tensorboard is passed as its first argument.
    writer = SummaryWriter(args.tensorboard)

    # --------------------------- using pre-trained params ---------------------------------- #

    # (1) Build the previous model and load its saved weights so that its
    #     state dict can serve as the source of pre-trained parameters.
    # from unet_3up_area.unet.unet_model import UNet as UNet_old
    from unet.unet_model import UNet as UNet_old
    net_old = UNet_old(n_channels=3, n_classes=1)
    net_old.load_state_dict(
        torch.load('../load_model_from_step1_area_branch_with_sknet/CPxx.pth'))
    net_old_dict = net_old.state_dict()

    # (2) Build the new model and take a copy of its current state dict.
    net = UNet(n_channels=3, n_classes=1)
    net_dict = net.state_dict()

    # (3) Copy over every pre-trained entry whose key also exists in the new
    #     model; keys present only in the new model keep their initial values.
    net_old_dict = {k: v for k, v in net_old_dict.items() if k in net_dict}
    net_dict.update(net_old_dict)  # overwrite the shared keys with pre-trained values
    net.load_state_dict(net_dict)  # load the merged state into the new model

    # Debug aid: list the names of the trainable parameters.
    # for name, param in net.named_parameters():
    #     if param.requires_grad:
    #         print(name)