예제 #1
0
    def __init__(self, conf):
        """Build the KernelGAN generator/discriminator pair, losses and optimizers.

        Args:
            conf: configuration object. This method reads ``input_crop_size``,
                ``G_kernel_size``, ``scale_factor``, ``g_lr``, ``d_lr``,
                ``beta1`` and ``input_image_path`` from it, and passes it whole
                to ``networks.Generator`` / ``networks.Discriminator``.

        When more than one CUDA device is visible, G and D are wrapped in
        ``nn.DataParallel``. Every buffer and loss module is placed on
        ``self._device``, so the object can also be constructed on a CPU-only
        host. Previously the CPU fallback crashed on unconditional ``.cuda()``.
        """
        # Expose GPUs 0-7 only if the caller has not already chosen devices.
        # Previously a user-provided value was silently overwritten.
        # This must happen before CUDA is initialised in this process to take effect.
        os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7')

        # Acquire configuration
        self.conf = conf
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Define the GAN
        self.G = networks.Generator(conf)
        self.D = networks.Discriminator(conf)

        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            print("Let's use", n_gpus, "GPUs!")
            # DataParallel splits each input batch along dim 0 across the GPUs.
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
        self.G.to(self._device)
        self.D.to(self._device)

        # Network attributes (output_size, forward_shave) live on the wrapped
        # module, not on the DataParallel wrapper. Decide by what was actually
        # wrapped rather than re-querying the device count.
        g_net = self.G.module if isinstance(self.G, nn.DataParallel) else self.G
        d_net = self.D.module if isinstance(self.D, nn.DataParallel) else self.D

        # Calculate D's input & output shape according to the shaving done by the networks
        self.d_input_shape = g_net.output_size
        self.d_output_shape = self.d_input_shape - d_net.forward_shave

        # Input tensors (allocated uninitialised, like torch.FloatTensor(sizes))
        self.g_input = torch.empty(1, 3, conf.input_crop_size, conf.input_crop_size,
                                   dtype=torch.float32, device=self._device)
        self.d_input = torch.empty(1, 3, self.d_input_shape, self.d_input_shape,
                                   dtype=torch.float32, device=self._device)

        # The kernel G is imitating
        self.curr_k = torch.empty(conf.G_kernel_size, conf.G_kernel_size,
                                  dtype=torch.float32, device=self._device)

        # Losses
        self.GAN_loss_layer = loss.GANLoss(
            d_last_layer_size=self.d_output_shape).to(self._device)
        self.bicubic_loss = loss.DownScaleLoss(
            scale_factor=conf.scale_factor).to(self._device)
        self.sum2one_loss = loss.SumOfWeightsLoss().to(self._device)
        self.boundaries_loss = loss.BoundariesLoss(
            k_size=conf.G_kernel_size).to(self._device)
        self.centralized_loss = loss.CentralizedLoss(
            k_size=conf.G_kernel_size,
            scale_factor=conf.scale_factor).to(self._device)
        self.sparse_loss = loss.SparsityLoss().to(self._device)
        self.loss_bicubic = 0

        # Define loss function
        self.criterionGAN = self.GAN_loss_layer.forward

        # Initialize networks weights (Module.apply recurses into the wrapped
        # module, so this also works when G/D are DataParallel)
        self.G.apply(networks.weights_init_G)
        self.D.apply(networks.weights_init_D)

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))

        print('*' * 60 + '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
예제 #2
0
    def __init__(self, conf):
        """Set up KernelGAN: networks, shape bookkeeping, losses and optimizers.

        Every network, buffer and loss module is created directly on the
        current CUDA device.

        Args:
            conf: configuration object. This method reads ``input_crop_size``,
                ``G_kernel_size``, ``scale_factor``, ``g_lr``, ``d_lr``,
                ``beta1`` and ``input_image_path``, and passes it whole to the
                generator/discriminator constructors.
        """
        self.conf = conf
        crop = conf.input_crop_size
        k_size = conf.G_kernel_size
        scale = conf.scale_factor
        adam_betas = (conf.beta1, 0.999)

        # Networks: G is built before D (construction order kept as-is).
        self.G = networks.Generator(conf).cuda()
        self.D = networks.Discriminator(conf).cuda()

        # D consumes G's output. D's output is smaller than its input by
        # D.forward_shave.
        self.d_input_shape = self.G.output_size
        self.d_output_shape = self.d_input_shape - self.D.forward_shave

        # Uninitialised GPU buffers for the network inputs.
        self.g_input = torch.FloatTensor(1, 3, crop, crop).cuda()
        self.d_input = torch.FloatTensor(1, 3, self.d_input_shape, self.d_input_shape).cuda()

        # Buffer for the kernel that G is imitating.
        self.curr_k = torch.FloatTensor(k_size, k_size).cuda()

        # Loss terms.
        self.GAN_loss_layer = loss.GANLoss(d_last_layer_size=self.d_output_shape).cuda()
        self.bicubic_loss = loss.DownScaleLoss(scale_factor=scale).cuda()
        self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
        self.boundaries_loss = loss.BoundariesLoss(k_size=k_size).cuda()
        self.centralized_loss = loss.CentralizedLoss(k_size=k_size, scale_factor=scale).cuda()
        self.sparse_loss = loss.SparsityLoss().cuda()
        self.loss_bicubic = 0

        # The adversarial criterion is the GAN loss layer's forward pass.
        self.criterionGAN = self.GAN_loss_layer.forward

        # Weight initialisation: G first, then D.
        self.G.apply(networks.weights_init_G)
        self.D.apply(networks.weights_init_D)

        # One Adam optimizer per network.
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=adam_betas)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=adam_betas)

        # Step counter used for TensorBoard logging.
        # (Ground-truth kernel logging to TensorBoard is currently disabled.)
        self.iteration = 0

        banner = '*' * 60
        print(banner + '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
예제 #3
0
    def __init__(self, conf, args, input, ch_indicator):
        """Set up a generator-only kernel estimator and its evaluation images.

        Builds the generator, the kernel-constraint losses, MSE/L1 criteria and
        an Adam optimizer for G. It also prepares two padded evaluation images:
        the ``input`` argument and ``<conf.v_input_dir>/validation_img.png``.

        Args:
            conf: configuration object. This method reads ``batch_size``,
                ``G_kernel_size``, ``scale_factor``, ``lambda_sum2one``,
                ``lambda_boundaries``, ``g_lr``, ``beta1`` and ``v_input_dir``,
                and passes it whole to ``Generator``.
            args: parsed arguments. Stored as ``self.args``; only
                ``args.white_balance`` is read here.
            input: image data converted with ``torch.FloatTensor``. It is
                presumably a 2-D (H, W) single-channel image, given the
                last-two-dims padding and the two ``unsqueeze`` calls below —
                TODO confirm. (The name shadows the ``input`` builtin.)
            ch_indicator: selects the colour tag passed to ``unprocess``:
                0 or 1 -> 'r', 2 -> 'g', 3 -> 'b'.

        NOTE(review): any other ``ch_indicator`` value leaves ``self.color_idx``
        unset, so the ``unprocess`` calls below raise ``AttributeError``.
        """
        # acquire the configuration
        self.conf = conf
        self.args = args

        # define the generator
        self.G = Generator(self.conf).cuda()

        # batch size of the network
        self.batch_size = self.conf.batch_size

        # Uninitialised buffer for the kernel K that G is imitating
        self.curr_k = torch.FloatTensor(self.conf.G_kernel_size,
                                        self.conf.G_kernel_size).cuda()

        # Losses (kernel constraints only; there is no adversarial loss here)
        self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
        self.boundaries_loss = loss.BoundariesLoss(
            k_size=self.conf.G_kernel_size).cuda()
        self.centralized_loss = loss.CentralizedLoss(
            k_size=self.conf.G_kernel_size,
            scale_factor=self.conf.scale_factor).cuda()
        self.sparse_loss = loss.SparsityLoss().cuda()

        # Constraint co-efficients
        self.lambda_sum2one = self.conf.lambda_sum2one
        self.lambda_boundaries = self.conf.lambda_boundaries

        # Running loss accumulators, all starting at zero
        self.loss_L1_sum = 0
        self.loss_boundaries_sum = 0
        self.loss_sum2one_sum = 0
        self.total_loss_L1_sum = 0

        # Define loss function
        self.criterionMSE = torch.nn.MSELoss().cuda()
        self.criterionL1 = torch.nn.L1Loss().cuda()

        # Initialize the weights of networks
        self.G.apply(weights_init_G)
        # No discriminator in this variant, so D is not initialised:
        # self.D.apply(weights_init_D)

        # Optimizers (generator only)
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=self.conf.g_lr,
                                            betas=(self.conf.beta1, 0.999))

        # test image used for output the results after convolution
        self.test_input_ori = input
        # Validation image loaded from disk. Presumably single-channel, since it
        # is handled exactly like `input` below — TODO confirm.
        self.test_input_val = mpimg.imread(
            os.path.join(self.conf.v_input_dir, 'validation_img.png'))
        ########################################################################
        # Map the channel indicator to the colour tag used by `unprocess`.
        # NOTE(review): both 0 and 1 map to 'r' — confirm this is intended.
        if (ch_indicator == 0) or (ch_indicator == 1):
            self.color_idx = 'r'
        elif ch_indicator == 2:
            self.color_idx = 'g'
        elif ch_indicator == 3:
            self.color_idx = 'b'
        self.wb = args.white_balance
        ########################################################################
        self.test_input_ori = torch.FloatTensor(self.test_input_ori)
        self.test_input_val = torch.FloatTensor(self.test_input_val)
        # Zero-pad the last two dims by (G_kernel_size - 1) on every side.
        # Then add two leading singleton dims and move the result to the GPU.
        self.test_input_ori = F.pad(
            self.test_input_ori,
            (self.conf.G_kernel_size - 1, self.conf.G_kernel_size - 1,
             self.conf.G_kernel_size - 1, self.conf.G_kernel_size - 1),
            "constant",
            value=0).unsqueeze(0).unsqueeze(0).cuda()
        self.test_input_val = F.pad(
            self.test_input_val,
            (self.conf.G_kernel_size - 1, self.conf.G_kernel_size - 1,
             self.conf.G_kernel_size - 1, self.conf.G_kernel_size - 1),
            "constant",
            value=0).unsqueeze(0).unsqueeze(0).cuda()

        # `unprocess` is defined elsewhere. It receives the white-balance
        # setting and colour tag; its exact transform is not visible here.
        self.test_input_ori = unprocess(self.test_input_ori, self.wb,
                                        self.color_idx)
        self.test_input_val = unprocess(self.test_input_val, self.wb,
                                        self.color_idx)

        # Zero tensors shaped like the processed images. The trailing .cuda() is
        # a no-op if `unprocess` already returns CUDA tensors.
        self.curr_img_ori = torch.zeros_like(self.test_input_ori).cuda()
        self.curr_img_val = torch.zeros_like(self.test_input_val).cuda()