Example #1
def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, downsample_mode='stride'):
    if NET_TYPE == 'ResNet':
        # TODO
        net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False)
    elif NET_TYPE == 'skip':
        net = skip(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d,
                                            num_channels_up =   [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u,
                                            num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 
                                            upsample_mode=upsample_mode, downsample_mode=downsample_mode,
                                            need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun)

    elif NET_TYPE == 'texture_nets':
        net = get_texture_nets(inp=input_depth, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False,pad=pad)

    elif NET_TYPE =='UNet':
        net = UNet(num_input_channels=input_depth, num_output_channels=3, 
                   feature_scale=4, more_layers=0, concat_x=False,
                   upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True)
    elif NET_TYPE == 'identity':
        assert input_depth == 3
        net = nn.Sequential()
    else:
        assert False

    return net
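
A minimal usage sketch for get_net (not part of the original snippet), assuming the DIP constructors used above (skip, ResNet, UNet, get_texture_nets) and torch are importable; the dtype cast mirrors the later examples.

import torch

# Build a 5-scale skip network mapping 32 noise channels to a 3-channel image.
input_depth = 32
net = get_net(input_depth, 'skip', pad='reflection', upsample_mode='bilinear',
              n_channels=3, skip_n33d=128, skip_n33u=128, skip_n11=4,
              num_scales=5, downsample_mode='stride')

# Cast to the GPU tensor type, as the other examples do.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
net = net.type(dtype)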
Example #2
    def initialize(self, opt):
        input_depth = opt.input_nc
        output_depth = opt.output_nc
        BaseModel.initialize(self, opt)
        self.net_shared = skip(input_depth, num_channels_down = [64, 128, 256, 256, 256],
                        num_channels_up   = [64, 128, 256, 256, 256],
                        num_channels_skip = [4, 4, 4, 4, 4],
                        upsample_mode=['nearest', 'nearest', 'bilinear', 'bilinear', 'bilinear'],
                        need_sigmoid=True, need_bias=True, pad='reflection')
        self.netDec_b = ResNet_decoders(opt.ngf, output_depth)


        self.net_input = self.get_noise(self.opt.input_nc, 'noise', (self.opt.fineSize, self.opt.fineSize))
        self.net_input_saved = self.net_input.detach().clone()
        self.noise = self.net_input.detach().clone()

        use_sigmoid = opt.no_lsgan
        self.netD = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD,
                                        opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if len(self.gpu_ids) > 0:
            dtype = torch.cuda.FloatTensor
            self.net_input = self.net_input.type(dtype).detach()
            self.net_shared = self.net_shared.type(dtype)
            self.netDec_b = self.netDec_b.type(dtype)
            self.netD = self.netD.type(dtype)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netDec_b, 'Dec_b', which_epoch)
            self.load_network(self.net_shared, 'Net_shared', which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', which_epoch)

        if self.isTrain:
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.mse = torch.nn.MSELoss()

            # initialize optimizers
            self.optimizer_Net = torch.optim.Adam(
                itertools.chain(self.net_shared.parameters(), self.netDec_b.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_b = torch.optim.Adam(self.netD.parameters(), lr=0.0002, betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_Net)
            self.optimizers.append(self.optimizer_D_b)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.net_shared)
        networks.print_network(self.netDec_b)
        networks.print_network(self.netD)
        print('-----------------------------------------------')
Example #3
def get_network_and_input(img_shape,
                          input_depth=32,
                          pad='reflection',
                          upsample_mode='bilinear',
                          use_interpolate=True,
                          align_corners=False,
                          act_fun='LeakyReLU',
                          skip_n33d=128,
                          skip_n33u=128,
                          skip_n11=4,
                          num_scales=5,
                          downsample_mode='stride',
                          INPUT='noise'):  # 'meshgrid'
    """ Getting the relevant network and network input (based on the image shape and input depth)
        We are using the same default params as in DIP article
        img_shape - the image shape (ch, x, y)
    """
    n_channels = img_shape[0]
    net = skip(input_depth,
               n_channels,
               num_channels_down=[skip_n33d] * num_scales if isinstance(skip_n33d, int) else skip_n33d,
               num_channels_up=[skip_n33u] * num_scales if isinstance(skip_n33u, int) else skip_n33u,
               num_channels_skip=[skip_n11] * num_scales if isinstance(skip_n11, int) else skip_n11,
               upsample_mode=upsample_mode,
               use_interpolate=use_interpolate,
               align_corners=align_corners,
               downsample_mode=downsample_mode,
               need_sigmoid=True,
               need_bias=True,
               pad=pad,
               act_fun=act_fun).type(dtype)
    net_input = get_noise(input_depth, INPUT,
                          img_shape[1:]).type(dtype).detach()
    return net, net_input
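
get_network_and_input relies on a module-level dtype and on skip and get_noise being in scope; a minimal usage sketch under those assumptions:

# Assumes skip, get_noise and a global dtype are already defined, e.g.
# dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

img_shape = (3, 256, 256)                       # (channels, height, width)
net, net_input = get_network_and_input(img_shape, input_depth=32, INPUT='noise')
out = net(net_input)                            # shape: (1, 3, 256, 256)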
Example #4
def allstart(ori_img_path, pol_img_path, mask_path, recover_img_path, pic_num):
	
	
	torch.backends.cudnn.enabled = True
	torch.backends.cudnn.benchmark = True
	dtype = torch.cuda.FloatTensor

	#PLOT = True
	PLOT = False
	imsize = -1
	#dim_div_by = 64
	dim_div_by = 32


	img_path  = pol_img_path
	NET_TYPE = 'skip_depth6' # one of skip_depth4|skip_depth2|UNET|ResNet



	img_pil, img_np = get_image(img_path, imsize)
	img_mask_pil, img_mask_np = get_image(mask_path, imsize)



	img_mask_pil = crop_image(img_mask_pil, dim_div_by)
	img_pil      = crop_image(img_pil,      dim_div_by)
	img_np      = pil_to_np(img_pil)
	img_mask_np = pil_to_np(img_mask_pil)

	img_mask_var = np_to_var(img_mask_np).type(dtype)
	#plot_image_grid([img_np, img_mask_np, img_mask_np*img_np], 3,11)


	pad = 'reflection' # 'zero'
	OPT_OVER = 'net'
	OPTIMIZER = 'adam'


	if True:
	    INPUT = 'noise'
	    input_depth = 32
	    LR = 0.01 
	    num_iter = 100000	
	    param_noise = False
	    show_every = 20000  
	    figsize = 5
	    
	    net = skip(input_depth, img_np.shape[0], 
		       num_channels_down = [16, 32, 64, 128, 128],
		       num_channels_up =   [16, 32, 64, 128, 128],
		       num_channels_skip =    [0, 0, 0, 0, 4],  
		       filter_size_up = 7, filter_size_down = 7, 
		       upsample_mode='nearest', filter_skip_size=1,
		       need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)


	net = net.type(dtype)
	net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype)


	# Compute number of parameters
	s  = sum(np.prod(list(p.size())) for p in net.parameters())
	print ('Number of params: %d' % s)

	# Loss
	mse = torch.nn.MSELoss().type(dtype)

	img_var = np_to_var(img_np).type(dtype)
	mask_var = np_to_var(img_mask_np).type(dtype)


	i = 0
	def closure():
		nonlocal i
	    
		if param_noise:
			for n in [x for x in net.parameters() if len(x.size()) == 4]:
				n.data += n.data.clone().normal_()*n.data.std()/50
	    
		out = net(net_input)
	   
		total_loss = mse(out * mask_var, img_var*mask_var)
		total_loss.backward()
		
		print ('Iteration %05d    Loss %f' % (i, total_loss.data[0]), '\r', end='')
		if  PLOT and i % show_every == 0:
			out_np = var_to_np(out)
			plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)
		
		i += 1

		return total_loss

	#print('picture number:',pic_num)
	logging.info('picture number:%s'%pic_num)

	p = get_params(OPT_OVER, net, net_input)
	optimize(OPTIMIZER, p, closure, LR, num_iter)


	out_np = var_to_np(net(net_input))
	r=out_np[0,:,:]
	g=out_np[1,:,:]
	b=out_np[2,:,:]
	r=r*255
	g=g*255
	b=b*255
	out=cv2.merge([b,g,r])
	cv2.imwrite(recover_img_path, out)
Example #5
        restart = True

        mse = torch.nn.MSELoss().type(dtype)

        # with torch.no_grad():

        img_var = Variable(np_to_torch(img_np).type(dtype))
        mask_var = Variable(np_to_torch(img_mask_np).type(dtype))
        # img_var = torch.tensor(img_np, requires_grad=False).type(dtype)
        # mask_var = torch.tensor(img_mask_np, requires_grad=False).type(dtype)

        # LR finder

        net = skip(input_depth, img_np.shape[0], 
                num_channels_down = [128]*5,
                num_channels_up   = [128]*5,
                num_channels_skip = [128]*5,
                upsample_mode='nearest', filter_skip_size=1, filter_size_up=3, filter_size_down=3,
                need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)

        net = net.type(dtype)

        net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype)


        mse_error = []
        start = time.time()
        i = 0
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()
        p = get_params(OPT_OVER, net, net_input)
Example #6


if 'vase.png' in img_path:
    INPUT = 'meshgrid'
    input_depth = 2
    LR = 0.01 
    num_iter = 5001
    param_noise = False
    show_every = 50
    figsize = 5
    reg_noise_std = 0.03
    
    net = skip(input_depth, img_np.shape[0], 
               num_channels_down = [128] * 5,
               num_channels_up   = [128] * 5,
               num_channels_skip = [0] * 5,  
               upsample_mode='nearest', filter_skip_size=1, filter_size_up=3, filter_size_down=3,
               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)
    
elif ('kate.png' in img_path) or ('peppers.png' in img_path):
    # Same params and net as in super-resolution and denoising
    INPUT = 'noise'
    input_depth = 32
    LR = 0.01 
    num_iter = 6001
    param_noise = False
    show_every = 50
    figsize = 5
    reg_noise_std = 0.03
    
    net = skip(input_depth, img_np.shape[0], 
Example #7
        img_masked = img_np * img_mask_np
        mask_var = np_to_torch(img_mask_np).type(dtype)

        # Visualization
        if args.plot:
            plot_image_grid([img_np, img_mask_np, img_mask_np * img_np], 3, 11)

        if args.net == 'default':
            from models.skip import skip
            net = skip(num_input_channels=args.input_depth,
                       num_output_channels=1,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')

        elif args.net == 'NAS':
            from models.skip_search_up import skip
            if args.i_NAS in [249, 250, 251]:
                exit(1)
            net = skip(model_index=args.i_NAS,
                       num_input_channels=args.input_depth,
                       num_output_channels=1,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
Example #8
pad = 'reflection'  # 'zero'
OPT_OVER = 'net'
OPTIMIZER = 'adam'

INPUT = 'noise'
input_depth = 2
LR = 0.01
num_iter = 3001
param_noise = False
show_every = 500
figsize = 5

net = skip(input_depth,
           img_np.shape[0],
           need_sigmoid=True,
           need_bias=True,
           pad=pad,
           act_fun='LeakyReLU').type(dtype)

net = net.type(dtype)
net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype)

# Compute number of parameters
s = sum(np.prod(list(p.size())) for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

img_var = np_to_var(img_np).type(dtype)
mask_var = np_to_var(img_mask_np).type(dtype)
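
The snippet stops after setting up the loss and the image/mask variables; a minimal sketch of the optimization step that typically follows, mirroring Examples #4 and #10 and assuming the DIP helpers get_params and optimize are imported:

# Masked-MSE closure, as in the other inpainting examples.
def closure():
    out = net(net_input)
    total_loss = mse(out * mask_var, img_var * mask_var)
    total_loss.backward()
    return total_loss

p = get_params(OPT_OVER, net, net_input)   # OPT_OVER = 'net': optimize network weights only
optimize(OPTIMIZER, p, closure, LR, num_iter)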
Example #9
def save_img(path, img):
    imgdata = img.data.cpu().numpy()
    io.imsave(path, imgdata / imgdata.max())


if __name__ == '__main__':
    #net = ImageUNet(256, batch_size=1, noise_reg=1.0/3.0).cuda()
    input_depth = 3
    input_noise = torch.randn(1, input_depth, 256, 256).cuda()
    input_noise = input_noise.detach()
    net = skip(input_depth,
               3,
               num_channels_down=[8, 16, 32, 64, 128],
               num_channels_up=[8, 16, 32, 64, 128],
               num_channels_skip=[0, 0, 0, 4, 4],
               upsample_mode='bilinear',
               need_sigmoid=True,
               need_bias=True,
               pad='reflection',
               act_fun='LeakyReLU')
    net.cuda()

    #optimizer = optim.Adam(net.parameters(), lr=0.01)

    theta_size = 30
    #rec_image = (io.imread('data/slice_058.png')[:, :, 0]/255.0).astype('float32')
    rec_image = (io.imread('data/slice_038.png') / 255.0).astype('float32')
    sinogram = radon(torch.tensor(rec_image),
                     thetas=torch.linspace(0, np.pi, steps=theta_size)).cuda()
    noisy_sinogram = add_gaussian_noise(sinogram, sigma=1.0)
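
The example ends after creating the noisy sinogram; a hedged sketch of the fitting loop that would follow, assuming the custom radon is differentiable and accepts GPU tensors (its call signature is taken from the line above) and reusing the commented-out Adam optimizer and save_img:

import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.01)
mse = torch.nn.MSELoss()
thetas = torch.linspace(0, np.pi, steps=theta_size)

for step in range(2000):
    optimizer.zero_grad()
    out = net(input_noise)[0, 0]                     # first output channel as the slice
    loss = mse(radon(out, thetas=thetas), noisy_sinogram)
    loss.backward()
    optimizer.step()

save_img('reconstruction.png', net(input_noise)[0, 0])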
Example #10
    def perform_inpainting(self):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float32  # dtype = torch.cuda.FloatTensor

        PLOT = self.args.plot
        imsize = -1
        dim_div_by = 64

        # Path to input image, mask and output
        img_path = (str(Path(__file__).resolve().parents[1]) +
                    "/data/Input data/" + self.image)
        mask_path = (str(Path(__file__).resolve().parents[1]) +
                     "/data/Input data/" + self.mask)
        if self.args.tuning == "basic":
            if self.vase_or_kate_or_library == "library":
                folder = (
                    str(Path(__file__).resolve().parents[1]) +
                    "/data/Output data/Hyperparameter optimization/Basic/" +
                    self.vase_or_kate_or_library + "/" + self.NET_TYPE + "/")
                Path(folder).mkdir(parents=True, exist_ok=True)
            else:
                folder = (
                    str(Path(__file__).resolve().parents[1]) +
                    "/data/Output data/Hyperparameter optimization/Basic/" +
                    self.vase_or_kate_or_library + "/")
                Path(folder).mkdir(parents=True, exist_ok=True)
            outp_path = folder + "/plotout"
        elif self.args.tuning == "advanced":
            print(
                "lr =",
                self.lr,
                "param_noise =",
                self.param_noise,
                "reg_noise_std =",
                self.reg_noise_std,
            )
            folder = (
                str(Path(__file__).resolve().parents[1]) +
                "/data/Output data/Hyperparameter optimization/Advanced/" +
                self.image + "/lr=" + str(self.lr) + ", param_noise=" +
                str(self.param_noise) + ", reg_noise_std=" +
                str(self.reg_noise_std) + "/")
            Path(folder).mkdir(parents=True, exist_ok=True)
            outp_path = folder + "/plotout"
        else:
            folder = (str(Path(__file__).resolve().parents[1]) +
                      "/data/Output data/" + self.image.split(".")[0] + "/")
            Path(folder).mkdir(parents=True, exist_ok=True)
            outp_path = folder + "/plotout"

        # Load mask
        img_pil, img_np = get_image(img_path, imsize)
        img_mask_pil, img_mask_np = get_image(mask_path, imsize)

        # Center crop
        img_mask_pil = crop_image(img_mask_pil, dim_div_by)
        img_pil = crop_image(img_pil, dim_div_by)

        img_np = pil_to_np(img_pil)
        img_mask_np = pil_to_np(img_mask_pil)

        # Visualize
        if PLOT:
            plot_image_grid([img_np, img_mask_np, img_mask_np * img_np], 3, 11)

        # Setup
        pad = "reflection"  # 'zero'
        OPT_OVER = "net"
        OPTIMIZER = "adam"
        num_iter = self.args.num_iter
        if self.args.tuning == "advanced":
            save_every = int(num_iter / 4)
        else:
            save_every = self.args.save_every

        if self.vase_or_kate_or_library == "vase":
            INPUT = "meshgrid"
            input_depth = 2
            LR = self.lr if self.lr else 0.01
            # num_iter = 5001
            param_noise = self.param_noise if self.param_noise else False
            # save_every = 50
            figsize = 32  # changed from 5
            reg_noise_std = self.reg_noise_std if self.reg_noise_std else 0.03

            net = (skip(
                input_depth,
                img_np.shape[0],
                num_channels_down=[128] * 5,
                num_channels_up=[128] * 5,
                num_channels_skip=[0] * 5,
                upsample_mode="nearest",
                filter_skip_size=1,
                filter_size_up=3,
                filter_size_down=3,
                need_sigmoid=True,
                need_bias=True,
                pad=pad,
                act_fun="LeakyReLU",
            ).type(dtype).to(device))

        elif self.vase_or_kate_or_library == "kate":
            # Same params and net as in super-resolution and denoising
            INPUT = "noise"
            input_depth = 32
            # num_iter = 6001
            LR = 0.01

            param_noise = False
            # save_every = 50
            figsize = 5
            reg_noise_std = 0.03

            net = (skip(
                input_depth,
                img_np.shape[0],
                num_channels_down=[128] * 5,
                num_channels_up=[128] * 5,
                num_channels_skip=[128] * 5,
                filter_size_up=3,
                filter_size_down=3,
                upsample_mode="nearest",
                filter_skip_size=1,
                need_sigmoid=True,
                need_bias=True,
                pad=pad,
                act_fun="LeakyReLU",
            ).type(dtype).to(device))

        elif self.vase_or_kate_or_library == "library":
            INPUT = "noise"
            input_depth = 1
            # num_iter = 3001
            # save_every = 50
            figsize = 8
            reg_noise_std = 0.00
            param_noise = True

            if "skip" in self.NET_TYPE:
                depth = int(self.NET_TYPE[-1])
                net = skip(
                    input_depth,
                    img_np.shape[0],
                    num_channels_down=[16, 32, 64, 128, 128, 128][:depth],
                    num_channels_up=[16, 32, 64, 128, 128, 128][:depth],
                    num_channels_skip=[0, 0, 0, 0, 0, 0][:depth],
                    filter_size_up=3,
                    filter_size_down=5,
                    filter_skip_size=1,
                    upsample_mode="nearest",  # downsample_mode='avg',
                    need1x1_up=False,
                    need_sigmoid=True,
                    need_bias=True,
                    pad=pad,
                    act_fun="LeakyReLU",
                )

                LR = 0.01

            elif self.NET_TYPE == "UNET":

                net = UNet(
                    num_input_channels=input_depth,
                    num_output_channels=3,
                    feature_scale=8,
                    more_layers=1,
                    concat_x=False,
                    upsample_mode="deconv",
                    pad="zero",
                    norm_layer=torch.nn.InstanceNorm2d,
                    need_sigmoid=True,
                    need_bias=True,
                )

                LR = 0.001
                param_noise = False

            elif self.NET_TYPE == "ResNet":

                net = ResNet(
                    input_depth,
                    img_np.shape[0],
                    8,
                    32,
                    need_sigmoid=True,
                    act_fun="LeakyReLU",
                )

                LR = 0.001
                param_noise = False

            else:
                assert False
        else:
            assert False

        net = net.type(dtype).to(device)
        net_input = (get_noise(input_depth, INPUT,
                               img_np.shape[1:]).type(dtype).to(device))

        # Compute number of parameters
        s = sum(np.prod(list(p.size())) for p in net.parameters())
        # print("Number of params: %d" % s)

        # Loss
        mse = torch.nn.MSELoss().type(dtype).to(device)

        img_var = np_to_torch(img_np).type(dtype).to(device)
        mask_var = np_to_torch(img_mask_np).type(dtype).to(device)

        # Main loop
        def closure():

            # global i

            if param_noise:
                for n in [x for x in net.parameters() if len(x.size()) == 4]:
                    n.data += n.data.clone().normal_() * n.data.std() / 50

            net_input = net_input_saved
            if reg_noise_std > 0:
                net_input = net_input_saved + (noise.normal_() * reg_noise_std)

            out = net(net_input)

            total_loss = mse(out * mask_var, img_var * mask_var)
            total_loss.backward()

            # print('Iteration %05d    Loss %f' % (i, total_loss.item()), '\r', end='')
            if self.i % save_every == 0 or self.i == num_iter - 1:
                out_np = torch_to_np(out)
                out_np = 255 * np.moveaxis(out_np, 0, 2)
                out_np = out_np.astype(np.uint8)
                filep = outp_path + str(self.i) + ".png"
                image = Image.fromarray(out_np)
                image.save(filep)

            self.i += 1

            return total_loss

        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()

        p = get_params(OPT_OVER, net, net_input)
        optimize(OPTIMIZER, p, closure, LR, num_iter)
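
        # Hedged sketch: once optimize() returns, the final reconstruction can be pulled
        # from the network the same way the closure saves intermediate results.
        # The output filename is a placeholder.
        out_np = torch_to_np(net(net_input_saved))            # (C, H, W), values in [0, 1]
        out_img = Image.fromarray((255 * np.moveaxis(out_np, 0, 2)).astype(np.uint8))
        out_img.save(outp_path + "final.png")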