def test_unet():
    """Visual smoke test: run a UNet on two STARE training samples and plot them.

    Builds ``nets.UNet(num_channels=1, antialias=True)`` (freshly constructed;
    no weights are loaded here) and feeds the first two samples of the STARE
    training split through it without gradient tracking. It applies a softmax
    over dim 1 and shows a 2x2 figure with:

    1. the de-normalised input images,
    2. channel 0 of the softmax output,
    3. the targets,
    4. the masks, when the dataset yields them.

    Nothing is asserted; the figure is for visual inspection only.
    """
    import torch
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    net = nets.UNet(num_channels=1, antialias=True)
    # torch.no_grad() only disables autograd. eval() is what switches layers
    # such as dropout / batch-norm (if the UNet has any) to inference
    # behaviour and stops this forward pass from updating running statistics.
    net.eval()

    # sets_ = get_datasets('DRIVE')
    sets_ = get_datasets('STARE')
    train_dataset = sets_['train']
    train_dataset.return_mask = True  # override

    # With return_mask set, a sample may be (img, mask, target) rather than
    # (img, target). Accept both layouts instead of failing to unpack.
    samples = [train_dataset[i] for i in range(2)]
    if len(samples[0]) == 3:
        img, mask, target = zip(*samples)
    else:
        img, target = zip(*samples)
        mask = None

    # Batch the per-sample tensors (assumed C x H x W each) into N x C x H x W.
    img = torch.stack(img)

    with torch.no_grad():
        pred_mask = F.softmax(net(img), 1)

    img = make_grid(img)
    pred_mask = make_grid(pred_mask)
    target = make_grid([t.unsqueeze(0) for t in target])

    # Undo the input normalisation so the image is displayable.
    # NOTE(review): assumes transforms[-2] is the Normalize step and that
    # index 1 of its mean/std belongs to the single input channel (green
    # channel?). Confirm against the dataset's transform pipeline.
    mean_ = train_dataset.transforms[-2].mean
    std_ = train_dataset.transforms[-2].std
    img = (std_[1] * img + mean_[1]).clamp(0, 1)
    img = np.moveaxis(img.numpy(), 0, -1)  # CHW -> HWC for imshow

    panels = [img]
    # NOTE(review): channel 0 of the softmax is shown. For a two-class
    # background/foreground head that would be the background probability.
    panels.append(pred_mask.numpy()[0])
    panels.append(target.numpy()[0])
    if mask is not None:
        mask = make_grid([m.unsqueeze(0) for m in mask])
        panels.append(np.moveaxis(mask.numpy(), 0, -1))

    fig = plt.figure(figsize=(11, 6))
    for position, panel in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.imshow(panel)
        plt.axis('off')

    fig.tight_layout()
    plt.show()
Example #2
0
 def build_net(self, mode='train'):
     """Build the UNet graph for ``mode`` and collect its named outputs.

     The network is configured from attributes on ``self`` (num_classes,
     upsample_mode, input_channel, use_bce_loss, use_dice_loss,
     class_weight, ignore_index).

     Args:
         mode: ``'train'``, ``'eval'`` or anything else (treated as
             inference).

     Returns:
         A ``(inputs, outputs)`` pair. ``inputs`` comes from
         ``model.generate_inputs()``. ``outputs`` is an OrderedDict holding:

         * ``'train'``: ``loss``. ``self.optimizer.minimize`` is also called
           on it (presumably a PaddlePaddle static-graph optimizer; confirm).
         * ``'eval'``: ``loss``, ``pred``, ``label``, ``mask``, taken from
           positions 0-3 of the model output.
         * otherwise: ``pred``, ``logit``, from positions 0-1.
     """
     unet = nets.UNet(
         self.num_classes,
         mode=mode,
         upsample_mode=self.upsample_mode,
         input_channel=self.input_channel,
         use_bce_loss=self.use_bce_loss,
         use_dice_loss=self.use_dice_loss,
         class_weight=self.class_weight,
         ignore_index=self.ignore_index,
     )
     feed_vars = unet.generate_inputs()
     net_out = unet.build_net(feed_vars)

     collected = OrderedDict()
     if mode == 'train':
         # In training mode the graph output is the loss itself.
         self.optimizer.minimize(net_out)
         collected['loss'] = net_out
     elif mode == 'eval':
         for position, key in enumerate(('loss', 'pred', 'label', 'mask')):
             collected[key] = net_out[position]
     else:
         for position, key in enumerate(('pred', 'logit')):
             collected[key] = net_out[position]
     return feed_vars, collected
Example #3
0
    def __init__(self, param):
        """Set up the data pipeline and network of N2CModel for one mode.

        The mode is selected by ``param['mode']``:

        * ``'train'``: training transform (with ``RotateFlip`` unless
          ``param['lpfir']`` or ``param['no_rf']`` is truthy), shuffled
          ``DataLoader``, ``nets.UNet`` moved to ``param['device']``, an
          Adam optimizer with learning rate ``param['lr']`` and an MSE
          criterion.
        * ``'test'``: test transform (no ``RotateFlip``), unshuffled
          ``DataLoader``, ``nets.UNet`` and an MSE criterion.
        * ``'run'``: transform using ``Resize(None, ...)``, a ``DataLoader``
          that serves the whole dataset as a single batch, and
          ``nets.UNet``. No criterion or optimizer is created.

        For any other mode none of these branches runs.

        Args:
            param: configuration dict. Always reads 'device', 'model_path'
                and 'mode'. Depending on the mode it also reads
                'batch_size', 'lpfir', 'no_rf', 'img_size', 'inC', 'midC',
                'sigma', 'dataset_path', 'lr' and 'data_path'.
        """
        super(N2CModel, self).__init__()
        self.device = param['device']
        # NOTE(review): not read in the branches below; presumably used when
        # saving/loading weights elsewhere. Confirm.
        self.model_path = param['model_path']

        if param['mode'] == 'train':
            self.batch_size = param['batch_size']

            # Either flag drops the RotateFlip step (presumably random
            # rotation/flip augmentation) from the training pipeline.
            if param['lpfir'] or param['no_rf']:
                self.train_data_transform = transforms.Compose([
                    util.transforms.Resize(param['img_size'], param['inC']),
                    util.transforms.AddNoise(param['sigma']),
                    util.transforms.ToTensor()
                ])
            else:
                self.train_data_transform = transforms.Compose([
                    util.transforms.Resize(param['img_size'], param['inC']),
                    util.transforms.AddNoise(param['sigma']),
                    util.transforms.RotateFlip(),
                    util.transforms.ToTensor()
                ])

            self.train_dataset = datasets.N2CDataset(
                dataset_path=param['dataset_path'],
                mode=param['mode'],
                transform=self.train_data_transform)
            # shuffle=True: sample order is reshuffled every epoch.
            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset, batch_size=self.batch_size, shuffle=True)

            self.net = nets.UNet(inC=param['inC'],
                                 midC=param['midC']).to(param['device'])

            self.learning_rate = param['lr']
            self.optimizer = optim.Adam(self.net.parameters(), param['lr'])
            self.criterion = nn.MSELoss()

        elif param['mode'] == 'test':
            self.batch_size = param['batch_size']

            # Same as the training pipeline minus RotateFlip. AddNoise is
            # still applied, so test inputs are noised too.
            self.test_data_transform = transforms.Compose([
                util.transforms.Resize(param['img_size'], param['inC']),
                util.transforms.AddNoise(param['sigma']),
                util.transforms.ToTensor()
            ])

            self.test_dataset = datasets.N2CDataset(
                dataset_path=param['dataset_path'],
                mode=param['mode'],
                transform=self.test_data_transform)
            # No shuffle: samples are served in dataset order.
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset, batch_size=self.batch_size)

            self.net = nets.UNet(inC=param['inC'],
                                 midC=param['midC']).to(param['device'])

            self.criterion = nn.MSELoss()

        elif param['mode'] == 'run':
            # Resize is given None as the size. Presumably this keeps the
            # input's own spatial size; confirm against util.transforms.Resize.
            self.run_data_transform = transforms.Compose([
                util.transforms.Resize(None, param['inC']),
                util.transforms.AddNoise(param['sigma']),
                util.transforms.ToTensor()
            ])

            self.run_dataset = datasets.N2CDataset(
                dataset_path=param['dataset_path'],
                mode=param['mode'],
                transform=self.run_data_transform,
                data_path=param['data_path'])
            # Whole dataset in a single batch. NOTE(review): an empty dataset
            # gives batch_size=0, which DataLoader rejects with ValueError.
            self.run_loader = torch.utils.data.DataLoader(
                self.run_dataset, batch_size=len(self.run_dataset))

            self.net = nets.UNet(inC=param['inC'],
                                 midC=param['midC']).to(param['device'])