    def forward(self, x):
        # Convert the RGB input to YUV and split it into luma (Y) and chroma (U, V).
        yuv_x = kornia.rgb_to_yuv(x)

        x_chroma = yuv_x[:, 1:, ...]   # U and V channels
        x_luma = yuv_x[:, 0:1, ...]    # Y channel

        # Encode chroma and luma with separate backbones, then fuse the features.
        chroma_features = self.chroma_resnet(x_chroma)
        luma_features = self.luma_resnet(x_luma)

        features = torch.cat([chroma_features, luma_features], dim=1)

        return self.fc(features)
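The `forward` above assumes the module already owns two backbones and a fusion head. A minimal sketch of a matching constructor, assuming torchvision ResNet-18 backbones and a hypothetical class name and `num_classes` argument (none of these appear in the original snippet):

import torch.nn as nn
import torchvision

class LumaChromaNet(nn.Module):  # hypothetical name
    def __init__(self, num_classes=10):
        super().__init__()
        # The chroma branch sees 2 input channels (U, V); the luma branch sees 1 (Y).
        self.chroma_resnet = torchvision.models.resnet18(num_classes=256)
        self.chroma_resnet.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.luma_resnet = torchvision.models.resnet18(num_classes=256)
        self.luma_resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Fuse the two 256-dimensional feature vectors and classify.
        self.fc = nn.Linear(512, num_classes)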
Example #2
def Pred():

    # Create a borderless, full-screen OpenCV window and keep a Win32 handle to it.
    cv2.namedWindow("cnn")
    a = win32gui.FindWindow(None, "cnn")
    win32gui.SetWindowLong(a, win32con.GWL_STYLE, win32con.WS_POPUP)
    x = 0
    y = 0
    width = 1920
    height = 1080

    while True:

        start_time = time.perf_counter()

        # Grab the next RGB frame (1080, 1920, 3) from the frame queue.
        img = fB.get(timeout=2)
        frame = torch.from_numpy(img).to(dev).reshape(
            [1, 1080, 1920, 3]).permute(0, 3, 1, 2).float()  # NHWC -> NCHW

        # Run the model on the luma channel only, then recombine with the original chroma.
        frameYUV = kornia.rgb_to_yuv(frame)
        frameY = frameYUV[:, 0:1, :, :].half()  # (1, 1, 1080, 1920)
        predY = model(frameY / 255) * 255

        f0 = predY.float()[0, 0, :, :]
        f1 = frameYUV[0, 1, :, :]
        f2 = frameYUV[0, 2, :, :]  # each (1080, 1920)
        pred = torch.stack([f0, f1, f2]).float()  # (3, 1080, 1920)
        # Back to RGB, then reorder the channels to BGR for OpenCV display.
        pred = kornia.yuv_to_rgb(pred).clip(0, 255).permute(1, 2, 0)[:, :, [2, 1, 0]]
        pred = pred.to(torch.uint8).cpu().numpy()

        cv2.imshow('cnn', pred)
        win32gui.SetWindowPos(a, win32con.HWND_TOPMOST, x, y, width, height,
                              win32con.SWP_SHOWWINDOW)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        x1 = time.perf_counter() - start_time
        start_timeD = time.perf_counter()
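The loop above pulls RGB frames out of a queue-like object `fB`, whose setup is not part of the excerpt. A minimal producer sketch, assuming an OpenCV-readable capture source (only the name `fB` comes from the original; everything else is an assumption):

import queue
import threading

import cv2

fB = queue.Queue(maxsize=4)

def capture(source=0):
    # Push 1920x1080 RGB uint8 frames, matching what Pred() expects.
    cap = cv2.VideoCapture(source)
    while cap.isOpened():
        ok, bgr = cap.read()
        if not ok:
            break
        rgb = cv2.cvtColor(cv2.resize(bgr, (1920, 1080)), cv2.COLOR_BGR2RGB)
        fB.put(rgb)
    cap.release()

threading.Thread(target=capture, daemon=True).start()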
Example #3
#############################
# Create a batch of images
xb_bgr = torch.cat([x_bgr, hflip(x_bgr), vflip(x_bgr), rot180(x_bgr)])
imshow(xb_bgr)

#############################
# Convert BGR to RGB
xb_rgb = kornia.bgr_to_rgb(xb_bgr)
imshow(xb_rgb)

#############################
# Convert RGB to grayscale
# NOTE: the image comes in torch.uint8, while kornia expects a floating point type
xb_gray = kornia.rgb_to_grayscale(xb_rgb.float() / 255.0)
imshow(xb_gray)

#############################
# Convert RGB to HSV
xb_hsv = kornia.rgb_to_hsv(xb_rgb.float() / 255.0)
imshow(xb_hsv[:, 2:3])  # show only the value (V) channel

#############################
# Convert RGB to YUV
# NOTE: the image comes in torch.uint8, while kornia expects a floating point type
yuv = kornia.rgb_to_yuv(xb_rgb.float() / 255.0)
y_channel = torchvision.utils.make_grid(yuv, nrow=2)[0, :, :]
plt.imshow(y_channel, cmap='gray', vmin=0, vmax=1)  # display only the Y channel
plt.axis('off')
plt.show()
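Example #3 relies on a few helpers that are not part of the excerpt (`x_bgr`, `imshow`, and the flip functions). A self-contained sketch of how they could look, using plain torch, OpenCV, and matplotlib (the image path is a placeholder):

import cv2
import matplotlib.pyplot as plt
import torch
import torchvision

# Load a BGR image and turn it into a (1, 3, H, W) uint8 tensor.
bgr = cv2.imread("image.png")  # placeholder path
x_bgr = torch.from_numpy(bgr).permute(2, 0, 1).unsqueeze(0)

# Flips implemented directly with torch.flip on the spatial dimensions.
def hflip(t):
    return torch.flip(t, dims=[-1])

def vflip(t):
    return torch.flip(t, dims=[-2])

def rot180(t):
    return torch.flip(t, dims=[-2, -1])

def imshow(batch):
    # Tile the batch into a grid and display it; rescale uint8-range input to [0, 1].
    batch = batch.float()
    if batch.max() > 1.0:
        batch = batch / 255.0
    grid = torchvision.utils.make_grid(batch, nrow=2)
    plt.imshow(grid.permute(1, 2, 0).clamp(0, 1).numpy())
    plt.axis('off')
    plt.show()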
Example #4
    def backward_G(self, epoch):

        self.loss_G_GAN_A = 0
        self.loss_G_GAN_B = 0
        self.loss_G_GAN_flash_B = 0
        self.loss_G_GAN_flash_A = 0
        # losses on the reconstructions (rec_A / rec_B) are commented out below
        # self.loss_G_GAN_recA = 0
        # self.loss_G_GAN_recB = 0


        self.loss_cycle_A = 0
        self.loss_cycle_B = 0
        # self.loss_G_L1_A_comp = 0
        # self.loss_G_L1_B_comp = 0

        self.loss_G_L1_A_comp_color = 0
        self.loss_G_L1_B_comp_color = 0
        self.loss_color_dslr_A = 0
        self.loss_color_dslr_B = 0

        if self.opt.D_flash:
            # GAN loss from the flash discriminator, conditioned on the real images.
            fake_AC = torch.cat((self.real_A, self.flash_from_decomposition), 1)
            pred_fake = self.netD_Flash(fake_AC)
            self.loss_G_GAN_flash_A = self.criterionGAN(pred_fake, True)
            fake_BC = torch.cat((self.real_B, self.flash_from_generation), 1)
            pred_fake = self.netD_Flash(fake_BC)
            self.loss_G_GAN_flash_B = self.criterionGAN(pred_fake, True)
        else:
            # Standard conditional GAN losses on the translated images.
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD_Decompostion(fake_AB)
            self.loss_G_GAN_A = self.criterionGAN(pred_fake, True)
            fake_AB_B = torch.cat((self.real_B, self.fake_A), 1)
            pred_fake = self.netD_Generation(fake_AB_B)
            self.loss_G_GAN_B = self.criterionGAN(pred_fake, True)

            # fake_AB = torch.cat((self.real_A, self.rec_A), 1)
            # pred_fake = self.netD_Decompostion(fake_AB)
            # self.loss_G_GAN_recA = self.criterionGAN(pred_fake, True)
            # fake_AB_B = torch.cat((self.real_B, self.rec_B), 1)
            # pred_fake = self.netD_Generation(fake_AB_B)
            # self.loss_G_GAN_recB = self.criterionGAN(pred_fake, True)

        # ## Flash L1 loss
        # if self.opt.lambda_comp != 0:
        #     self.loss_G_L1_A_comp = self.criterionL1(self.flash_from_decomposition, self.real_C) * self.opt.lambda_comp
        #     self.loss_G_L1_B_comp = self.criterionL1(self.flash_from_generation, self.real_C) * self.opt.lambda_comp


        ## Flash Color Loss (L1 on the chrominance of the predicted flash images)
        if self.opt.lambda_color_uv != 0:
            # Note: [:, 1:2] selects only the U channel of the YUV image.
            fake_C_A = kornia.rgb_to_yuv(self.flash_from_decomposition)[:, 1:2, :, :]
            fake_C_B = kornia.rgb_to_yuv(self.flash_from_generation)[:, 1:2, :, :]
            real_C_color = kornia.rgb_to_yuv(self.real_C)[:, 1:2, :, :]

            self.loss_G_L1_A_comp_color = self.criterionL1(fake_C_A, real_C_color) * self.opt.lambda_color_uv
            self.loss_G_L1_B_comp_color = self.criterionL1(fake_C_B, real_C_color) * self.opt.lambda_color_uv


        if epoch >= self.opt.cycle_epoch:
            self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * self.opt.lambda_A
            self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * self.opt.lambda_B

        if self.opt.dslr_color_loss:
            # Compare heavily blurred versions so only low-frequency color content is penalized.
            real_A_blurred = kornia.gaussian_blur2d(self.real_A, (21, 21), (3, 3))
            real_B_blurred = kornia.gaussian_blur2d(self.real_B, (21, 21), (3, 3))
            fake_A_blurred = kornia.gaussian_blur2d(self.fake_A, (21, 21), (3, 3))
            fake_B_blurred = kornia.gaussian_blur2d(self.fake_B, (21, 21), (3, 3))
            self.loss_color_dslr_A = self.criterionL1(real_A_blurred, fake_A_blurred) * self.opt.dslr_color_loss
            self.loss_color_dslr_B = self.criterionL1(real_B_blurred, fake_B_blurred) * self.opt.dslr_color_loss


        self.loss_G_L1_A = self.criterionL1(self.fake_A, self.real_A) * self.opt.lambda_L1
        self.loss_G_L1_B = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        self.loss_G = self.loss_color_dslr_A + self.loss_color_dslr_B + \
                      self.loss_G_L1_B_comp_color + self.loss_G_L1_A_comp_color + \
                      self.loss_G_GAN_flash_A + self.loss_G_GAN_flash_B + \
                      self.loss_G_GAN_B + self.loss_G_GAN_A + \
                      self.loss_cycle_A + self.loss_cycle_B + \
                      self.loss_G_L1_A + self.loss_G_L1_B

        self.loss_G.backward()
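The chrominance term above is the piece of `backward_G` that uses kornia.rgb_to_yuv. The same idea as a standalone helper (the name `uv_color_loss` is hypothetical; unlike the snippet above, it compares both U and V channels):

import torch.nn.functional as F
import kornia

def uv_color_loss(pred_rgb, target_rgb, weight=1.0):
    # Keep only the chrominance (U, V) planes so luminance differences are ignored.
    pred_uv = kornia.rgb_to_yuv(pred_rgb)[:, 1:, :, :]
    target_uv = kornia.rgb_to_yuv(target_rgb)[:, 1:, :, :]
    return F.l1_loss(pred_uv, target_uv) * weight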
Example #5
    def test_rgb_to_yuv_shape_bad(self, bad_input_shapes):
        with pytest.raises(ValueError):
            out = kornia.rgb_to_yuv(torch.ones(*bad_input_shapes))
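The test above receives `bad_input_shapes` from a pytest parametrization that is not included in the excerpt. A plausible sketch of how it could be attached (the concrete shapes are assumptions, chosen so the channel dimension is not 3):

    @pytest.mark.parametrize("bad_input_shapes", [
        (1, 4, 4),     # channel dimension is 1, not 3
        (2, 4, 4),     # channel dimension is 2, not 3
        (1, 2, 4, 4),  # batched input with 2 channels
    ])
    def test_rgb_to_yuv_shape_bad(self, bad_input_shapes):
        with pytest.raises(ValueError):
            kornia.rgb_to_yuv(torch.ones(*bad_input_shapes))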
Example #6
    def test_rgb_to_yuv_type(self):
        with pytest.raises(TypeError):
            out = kornia.rgb_to_yuv(1)
Example #7
    def test_rgb_to_yuv_batch_shape(self, device):
        batch_size, channels, height, width = 2, 3, 4, 5
        img = torch.ones(batch_size, channels, height, width).to(device)
        assert kornia.rgb_to_yuv(img).shape == \
            (batch_size, channels, height, width)
Example #8
    def test_rgb_to_yuv_shape(self, device):
        channels, height, width = 3, 4, 5
        img = torch.ones(channels, height, width).to(device)
        assert kornia.rgb_to_yuv(img).shape == (channels, height, width)
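For completeness, a round-trip sanity check in the same style as the tests above; it assumes the same `device` fixture and uses kornia.yuv_to_rgb, which already appears in Example #2:

    def test_rgb_to_yuv_roundtrip(self, device):
        img = torch.rand(1, 3, 4, 5).to(device)
        yuv = kornia.rgb_to_yuv(img)
        rgb = kornia.yuv_to_rgb(yuv)
        # Converting to YUV and back should approximately recover the input.
        assert torch.allclose(rgb, img, atol=1e-4)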