# Example No. 1 (score: 0)
def test(args):
    """Run inference on the validation set and save predictions to disk.

    Loads U-Net weights from ``args.ckpt``, feeds every sample of the
    ``data/val`` liver dataset through the model on the CPU, and writes the
    ground-truth mask and the predicted mask (scaled to 0-255) as PNG files
    under ``./data/predict``.
    """
    model = Unet(1, 1)
    # Fix: the forward pass below runs on CPU tensors, so map the checkpoint
    # to the CPU as well; map_location='cuda' fails on CPU-only machines.
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    liver_dataset = LiverDataset("data/val",
                                 transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)

    save_root = './data/predict'
    # Fix: make sure the output directory exists before writing into it.
    os.makedirs(save_root, exist_ok=True)

    model.eval()
    plt.ion()
    index = 0
    with torch.no_grad():
        for x, ground in dataloaders:
            x = x.type(torch.FloatTensor)
            y = model(x)
            # Collapse the batch dim, then restore a leading channel dim so
            # transform_invert receives a (C, H, W)-style tensor.
            x = torch.squeeze(x)
            x = x.unsqueeze(0)
            ground = torch.squeeze(ground)
            ground = ground.unsqueeze(0)
            img_ground = transform_invert(ground, y_transforms)
            img_x = transform_invert(x, x_transforms)
            img_y = torch.squeeze(y).numpy()
            src_path = os.path.join(save_root, "predict_%d_s.png" % index)
            save_path = os.path.join(save_root, "predict_%d_o.png" % index)
            ground_path = os.path.join(save_root, "predict_%d_g.png" % index)
            img_ground.save(ground_path)
            # img_x.save(src_path)
            # Prediction is assumed to be in [0, 1]; rescale for 8-bit PNG.
            cv2.imwrite(save_path, img_y * 255)
            index += 1
# Example No. 2 (score: 0)
def homework2(test_dir):
    """Preview RMB test images corrupted with pepper noise.

    Every image found in ``test_dir`` is resized to 224x224, has pepper
    noise applied, is converted back from tensor form, and is displayed
    briefly with matplotlib before moving on to the next one.
    """
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        AddPepperNoise(0.9, p=0.8),
        transforms.ToTensor(),
    ])
    dataset = RMBDataset(data_dir=test_dir, transform=pipeline)
    loader = DataLoader(dataset=dataset, batch_size=1)

    for batch in loader:
        images, _labels = batch          # B C H W
        sample = images[0, ...]          # C H W
        shown = transform_invert(sample, pipeline)
        plt.imshow(shown)
        plt.show()
        plt.pause(0.5)
        plt.close()
# Example No. 3 (score: 0)
def homework1(test_dir):
    """Preview RMB test images under a configurable augmentation pipeline.

    The commented-out transforms are the alternatives explored in the
    homework; uncomment one at a time to inspect its effect.
    """
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        # 1 crop
        #transforms.CenterCrop(120),
        # 2 flip
        #transforms.RandomHorizontalFlip(p=1),
        # 3 rotation
        #transforms.RandomRotation(45),
        # 4 hue
        #transforms.ColorJitter(hue=0.4),
        # 5 saturation
        #transforms.ColorJitter(saturation=50),
        # 6 grayscale
        #transforms.Grayscale(3),
        # 7 shear
        #transforms.RandomAffine(0,shear=45),
        # 8 scale
        #transforms.RandomAffine(0,scale=(0.5,0.5)),
        # 9 translation
        #transforms.RandomAffine(0,translate=(0.5,0)),
        # 10 occlusion (random erasing)
        #transforms.ToTensor(),
        #transforms.RandomErasing(p=0.5,scale=(0.1,0.4),value=0),
        transforms.ToTensor(),
    ])
    # Build the dataset and a one-image-per-batch loader over test_dir.
    dataset = RMBDataset(data_dir=test_dir, transform=pipeline)
    loader = DataLoader(dataset=dataset, batch_size=1)

    for batch in loader:
        images, _labels = batch          # B C H W
        sample = images[0, ...]          # C H W
        shown = transform_invert(sample, pipeline)
        plt.imshow(shown)
        plt.show()
        plt.pause(0.5)
        plt.close()
# NOTE(review): this is the tail of a script section whose earlier part is
# outside this chunk — input_tensor, output_tensor_3 and conv_layer4 are
# defined above; presumably conv_layer4 is another convolution variant being
# compared here. Confirm against the full file.
print("像素值:", output_tensor_3)  # prints "pixel values" of the previous result
print("=======================================")
# Apply the fourth convolution layer and report tensor sizes before/after.
output_tensor_4 = conv_layer4(input_tensor)
print("卷积前尺寸:{}\n卷积后尺寸:{}".format(input_tensor.shape, output_tensor_4.shape))  # "size before conv / size after conv"
print("像素值:", output_tensor_4)
# ================================ 2 ================================

import matplotlib.pyplot as plt
from common_tools import transform_invert
from torchvision import transforms
from PIL import Image

# Load an RGB image and convert it to a float tensor for convolution.
lena_img = "lena.png"
img = Image.open(lena_img).convert('RGB')
img_transform = transforms.Compose([transforms.ToTensor()])
img_tensor = img_transform(img)
img_tensor.unsqueeze_(dim=0)  # add batch dim -> (1, 3, H, W)

# 3-D convolution over a depth-1 "volume": a (3,3,3) kernel with depth
# padding 1 keeps the depth dimension at size 1 while H and W shrink by 2.
conv_layer = nn.Conv3d(3, 1, (3, 3, 3), padding=(1, 0, 0), bias=False)
nn.init.xavier_normal_(conv_layer.weight.data)
img_tensor.unsqueeze_(dim=2)  # add depth dim -> (1, 3, 1, H, W)
img_conv = conv_layer(img_tensor)  # -> (1, 1, 1, H-2, W-2)

#img_conv = transform_invert(img_conv[0, 0:1, ...], img_transform)
# NOTE(review): the slice [:, :, ...] below is a no-op, so transform_invert
# receives the full 5-D tensor; the commented line above hints at the
# intended slicing — verify what shape transform_invert expects before fixing.
img_conv = transform_invert(img_conv[:, :, ...], img_transform)
img_raw = transform_invert(img_tensor.squeeze(), img_transform)  # back to a displayable image
plt.subplot(122).imshow(img_conv, cmap='gray')
plt.subplot(121).imshow(img_raw)
plt.show()