コード例 #1
0
def inference(args, dataloader):
    """Run a trained segmentation model over ``dataloader`` and save color masks.

    The architecture is chosen from ``args.model`` (case-insensitive):
    ``'fcn32s'`` -> ``VGG16_FCN32s``, ``'fcn8s'`` -> ``VGG16_FCN8s``, anything
    else -> ``UNet``. All use 7 output classes. Weights are read from
    ``{args.model_path}/best_<name>.pth`` onto the CPU, after which the model
    is switched to eval mode and moved to CUDA.

    Args:
        args: Object exposing ``model`` and ``model_path`` attributes.
        dataloader: Iterable of ``(images, paths)`` batches, where ``images``
            is a tensor batch fed to the model and ``paths`` holds one output
            file path per image. ``dataloader.dataset`` must support ``len``
            (used for the progress counter).

    Each pixel whose predicted class index is ``c`` is painted with
    ``class_map[c]`` (a module-level palette), and the RGB image is written to
    the matching entry of ``paths``.
    """
    name = str(args.model).lower()
    if name == 'fcn32s':
        model = VGG16_FCN32s(n_classes=7)
        weights_file = f'{args.model_path}/best_fcn32s.pth'
    elif name == 'fcn8s':
        model = VGG16_FCN8s(n_classes=7)
        weights_file = f'{args.model_path}/best_fcn8s.pth'
    else:
        model = UNet(n_channels=3, n_classes=7)
        weights_file = f'{args.model_path}/best_unet.pth'
    model.load_state_dict(torch.load(weights_file, map_location='cpu'))
    model.eval()
    model.cuda()

    total = len(dataloader.dataset)
    seen = 0
    # Pure inference: no_grad stops autograd from recording the forward pass,
    # which would otherwise keep every intermediate activation alive.
    with torch.no_grad():
        for images, paths in dataloader:
            batch = images.size(0)
            logits = model(images.cuda())
            # Softmax is monotonic, so the argmax over the raw class scores
            # (dim 1, channels-first output) selects the same class without
            # materialising a probability tensor.
            predict = torch.argmax(logits, dim=1).cpu().numpy()

            for s in range(batch):
                mask = predict[s]
                # Size the canvas from the prediction instead of assuming 512x512.
                pred_img = np.zeros((*mask.shape, 3), dtype=np.uint8)
                for c in range(len(class_map)):
                    pred_img[mask == c] = class_map[c]
                Image.fromarray(pred_img).save(paths[s])

            # Count images actually processed so a short final batch does not
            # push the counter past the dataset size.
            seen += batch
            print(f'\t[{seen}/{total}]', end='  \r')
コード例 #2
0
ファイル: test.py プロジェクト: songkq/Polyp-Seg
from os.path import *
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt

# Evaluation configuration: batch size plus data/checkpoint locations
# (relative paths, resolved against the current working directory).
batch_size = 1
test_image_dir, test_label_dir = '../test/images/', '../test/labels/'
checkpoints_dir = '../checkpoints/'
# Saving predicted masks is currently disabled:
# save_path = 'test_results/'
# if not exists(save_path):
#     os.mkdir(save_path)

# Build the network once. Module.cuda() returns the module itself, so the
# GPU move is chained onto construction; eval mode is then enabled.
net = UNet(n_channels=3, n_classes=1).cuda()
net.eval()

# Evaluate every 5th saved checkpoint: CP1.pth, CP6.pth, ..., CP146.pth.
for checkpoint in range(1, 31):
    net.load_state_dict(
        torch.load(checkpoints_dir + 'CP' + str(5 * checkpoint - 4) + '.pth'))

    # NOTE(review): the transform, dataset and DataLoader below do not depend
    # on `checkpoint`, yet they are rebuilt on every iteration; consider
    # constructing them once above the loop.
    transform1 = transforms.Compose([ToTensor()])
    test_dataset = Dataset_unet(test_image_dir,
                                test_label_dir,
                                transform=transform1)
    dataloader = DataLoader(test_dataset, batch_size=batch_size)
    dataset_sizes = len(test_dataset)
    # Truncating division: matches the number of DataLoader batches (default
    # drop_last=False yields a ceil) only when batch_size divides the dataset
    # size, which always holds for batch_size = 1.
    batch_num = int(dataset_sizes / batch_size)

    # Per-checkpoint metric accumulators; presumably updated in the rest of
    # the loop body, which is not shown in this excerpt.
    Sensitivity = 0
    Specificity = 0
コード例 #3
0
ファイル: train.py プロジェクト: yangyongjx/WiSLAR
    loss_y = 0  # NOTE(review): not used anywhere in this excerpt
    # --- Training pass: one optimizer step per batch. ---
    # NOTE(review): `unet.eval()` is called further down; confirm
    # `unet.train()` runs before this loop every epoch (not visible here),
    # otherwise later epochs train with eval-mode layers.
    for (samples, labels) in tqdm(train_data_loader):
        # Variable is deprecated since PyTorch 0.4 and just returns the tensor.
        samplesV = Variable(samples.cuda())
        labelsV = Variable(labels.cuda())

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        predict_label = unet(samplesV)

        loss = criterion(predict_label, labelsV)
        # .item() copies the loss to the host every step, forcing a device sync.
        print(loss.item())

        loss.backward()
        optimizer.step()

    # --- Evaluation pass over the *training* loader (train accuracy/loss). ---
    unet.eval()
    loss_x = 0
    correct_train = 0
    for i, (samples, labels) in enumerate(train_data_loader):
        with torch.no_grad():
            samplesV = Variable(samples.cuda())
            labelsV = Variable(labels.cuda())

            predict_label = unet(samplesV)

            # Index of the maximum along dim 1, taken as the predicted class
            # (assumes dim 1 is the class dimension -- TODO confirm).
            prediction = predict_label.data.max(1)[1]
            # Count of predictions equal to the labels; accumulates as a tensor.
            correct_train += prediction.eq(labelsV.data.long()).sum()

            loss = criterion(predict_label, labelsV)
            loss_x += loss.item()