Example #2
def get_model(m_path):

    unet = UNet(in_channels=3, out_channels=1, init_features=32)
    checkpoint = torch.load(m_path, map_location="cpu")

    # strip the 'module.' prefix that nn.DataParallel adds to parameter names
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    unet.load_state_dict(new_state_dict)

    return unet
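
# A usage sketch for get_model (hypothetical; not part of the original snippet).
# The checkpoint path reuses the one referenced elsewhere in this file:
#
#     unet = get_model('tools/checkpoint_199_epoch.pkl')
#     unet.eval()  # switch to inference mode before predicting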
class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent=parent)
        self.setupUi(self)
        self.setWindowIcon(QIcon('Software GUI/beauty.ico'))
        sys.stdout = EmittingStr(textWritten=self.outputWritten)
        sys.stderr = EmittingStr(textWritten=self.outputWritten)
        self.pushButton.clicked.connect(self.load_source)
        self.pushButton_2.clicked.connect(self.load_target)
        self.pushButton_3.clicked.connect(self.play_input_video)
        self.pushButton_4.clicked.connect(self.play_result_video)
        self.pushButton_5.clicked.connect(self.save_result)
        self.pushButton_skyrpl.clicked.connect(self.skyrpl)

        # self.checkpoint_load = 'test6_lovasz_1e-2/checkpoint_19_epoch.pkl'
        # self.checkpoint_load = 'test6_lovasz_1e-2/bestdice_min_38.57%_checkpoint_55_epoch.pkl'
        # self.checkpoint_load = 'test4_lovasz_1e-2/bestdice_min_47.90%_checkpoint_35_epoch.pkl'
        self.checkpoint_load = 'tools/checkpoint_199_epoch.pkl'
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.net = UNet(
            in_channels=3, out_channels=1,
            init_features=32)  # init_features is 64 in the standard UNet
        self.net.to(self.device)
        self.net.eval()
        if self.checkpoint_load is not None:
            checkpoint = torch.load(self.checkpoint_load)
            self.net.load_state_dict(checkpoint['model_state_dict'])
            print(
                '\nWelcome to Magic Sky Software.\nPyTorch model loaded checkpoint from %s'
                % self.checkpoint_load)
        else:
            raise Exception("\nPlease specify the checkpoint")
        set_seed()  # set the random seed

    def outputWritten(self, text):
        cursor = self.textBrowser.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textBrowser.setTextCursor(cursor)
        self.textBrowser.ensureCursorVisible()

    def load_source(self):
        print('Loading source file')
        if self.modetext.currentText() == 'Photo':
            file_filter = r"IMAGE(*.jpg;*.jpeg;*.png);;ALL FILE(*)"
        else:
            file_filter = r"VIDEO(*.mp4);;ALL FILE(*)"
        self.srcname, file_type = QFileDialog.getOpenFileName(
            self, caption='source', directory='Demo', filter=file_filter)
        if self.srcname != "":
            print('Source loaded successfully: {0}'.format(self.srcname))
            self.sourcetext.setText(self.srcname)
        else:
            print('Load failed.')
        if self.modetext.currentText() == 'Photo' and (self.srcname != ''):
            self.src = cv2.imread(self.srcname)
            self.src = cv2.cvtColor(self.src, cv2.COLOR_BGR2RGB)
            self.src_h, self.src_w, self.src_c = self.src.shape
            self.scene_show(self.srcname, self.sourceView)
        elif self.modetext.currentText() == 'Video' and (self.srcname != ''):
            cap = cv2.VideoCapture(self.srcname)
            success, frame = cap.read()
            assert success
            cv2.imwrite('temp/frame1.jpg', frame)
            cap.release()
            self.scene_show('temp/frame1.jpg', self.sourceView)

    def load_target(self):
        print('Loading target file')
        file_filter = r"IMAGE(*.jpg;*.jpeg;*.png);;ALL FILE(*)"
        self.tgtname, file_type = QFileDialog.getOpenFileName(
            self, caption='target', directory='sky', filter=file_filter)
        if self.tgtname != "":
            print('Target loaded successfully: {0}'.format(self.tgtname))
            self.targettext.setText(self.tgtname)
        else:
            print('Load failed.')
        if self.tgtname != '':
            self.tgt = cv2.imread(self.tgtname)
            self.tgt = cv2.cvtColor(self.tgt, cv2.COLOR_BGR2RGB)
            self.tgt_h, self.tgt_w, self.tgt_c = self.tgt.shape
            self.scene_show(self.tgtname, self.targetView)

    def scene_show(self, filename, graphic_view):
        assert isinstance(graphic_view, QtWidgets.QGraphicsView)
        h, w, c = cv2.imread(filename).shape
        frame = QtGui.QImage(filename)
        pix = QtGui.QPixmap.fromImage(frame)
        item = QGraphicsPixmapItem(pix)
        scene = QGraphicsScene()
        scale = min(graphic_view.height() / (1.02 * h),
                    graphic_view.width() / (1.02 * w))
        item.setScale(scale)
        scene.addItem(item)
        graphic_view.setScene(scene)
        return

    def skyrpl(self):
        print('Replacing Sky')
        if self.modetext.currentText() == 'Photo':
            result = photo_replace(self.src, self.tgt, self.net)
            cv2.imwrite('temp/results.jpg', result[:, :, ::-1])
            self.scene_show('temp/results.jpg', self.resultView)
            print('Sky Replaced')
        elif self.modetext.currentText() == 'Video':
            self.cap = cv2.VideoCapture(self.srcname)
            fps = self.cap.get(cv2.CAP_PROP_FPS)
            frameCount = self.cap.get(cv2.CAP_PROP_FRAME_COUNT)
            size = (int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            self.videoWriter = cv2.VideoWriter('temp/result.mp4',
                                               cv2.VideoWriter_fourcc(*'mp4v'),
                                               fps, size)
            assert self.cap.isOpened()

            tic = time.time()
            # def video_thread(frameCount, tic):
            for index in range(int(frameCount)):
                success, frame = self.cap.read()
                if success:
                    result = photo_replace(frame[..., ::-1], self.tgt,
                                           self.net, 1)
                    if index == 0:
                        cv2.imwrite('temp/frame1.jpg', frame)
                        cv2.imwrite('temp/result1.jpg', result[..., ::-1])
                        self.scene_show('temp/frame1.jpg', self.sourceView)
                        self.scene_show('temp/result1.jpg', self.resultView)
                    if index % 50 == 0:
                        print("Replace %d, time %.2f" % (index,
                                                         (time.time() - tic)))
                    self.videoWriter.write(result[..., ::-1])
                else:
                    raise RuntimeError('Failed to read frame %d' % index)
            self.cap.release()
            self.videoWriter.release()
            print("Infer Done! Time %.2f" % (time.time() - tic))

    def play_input_video(self):
        assert self.modetext.currentText() == 'Video'
        print('Play input video')
        global playmode
        playmode = "Input"
        global videoName  # set as a global so the playback thread can read it
        videoName = self.srcname
        # cap = cv2.VideoCapture(str(videoName))
        self.th = Thread(self)
        self.th.changeSrcPixmap.connect(self.setInputImage)
        self.th.start()

    def play_result_video(self):
        assert self.modetext.currentText() == 'Video'
        print('Play result video')
        global playmode
        playmode = "Result"
        global videoName  # set as a global so the playback thread can read it
        videoName = "temp/result.mp4"
        # cap = cv2.VideoCapture(str(videoName))
        self.th = Thread(self)
        self.th.changeResPixmap.connect(self.setResImage)
        self.th.start()

    def setResImage(self, Qframe):
        pix = QtGui.QPixmap.fromImage(Qframe)
        item = QGraphicsPixmapItem(pix)
        scene = QGraphicsScene()
        scale = min(self.resultView.height() / (1.02 * pix.height()),
                    self.resultView.width() / (1.02 * pix.width()))
        item.setScale(scale)
        scene.addItem(item)
        self.resultView.setScene(scene)

    def setInputImage(self, Qframe):
        pix = QtGui.QPixmap.fromImage(Qframe)
        item = QGraphicsPixmapItem(pix)
        scene = QGraphicsScene()
        scale = min(self.sourceView.height() / (1.02 * pix.height()),
                    self.sourceView.width() / (1.02 * pix.width()))
        item.setScale(scale)
        scene.addItem(item)
        self.sourceView.setScene(scene)

    def save_result(self):
        print('Saving result.')
        if self.modetext.currentText() == 'Photo':
            file_filter = r"(*.jpg)"
        else:
            file_filter = r"(*.mp4)"
        save_filename, filetype = QFileDialog.getSaveFileName(
            self,
            caption="Save result to: ",
            directory='results',
            filter=file_filter)
        print('Saving to: %s (filter: %s)' % (save_filename, filetype))
        if self.modetext.currentText() == 'Photo':
            self.mycopyfile("temp/results.jpg", save_filename)
        else:
            self.mycopyfile("temp/result.mp4", save_filename)

    def mycopyfile(self, srcfile, dstfile):
        assert srcfile.endswith(('mp4', 'jpg'))
        fpath, fname = os.path.split(dstfile)  # split the destination into directory and filename
        if not os.path.exists(fpath):
            os.makedirs(fpath)  # create the directory if needed
        shutil.copyfile(srcfile, dstfile)  # copy the file
        print("Results saved -> %s" % (dstfile))
Example #4
    # step 1: build the training and validation sets
    trainset = SkyDataset(trainset_path)
    testset = SkyDataset(testset_path)

    train_loader = DataLoader(trainset,
                              batch_size=BATCH_SIZE,
                              drop_last=False,
                              shuffle=True)
    valid_loader = DataLoader(testset,
                              batch_size=1,
                              drop_last=False,
                              shuffle=False)

    # step 2
    net = UNet(in_channels=3, out_channels=1,
               init_features=32)  # init_features is 64 in the standard UNet

    net.to(device)

    # step 3
    # loss_fn = nn.MSELoss()
    loss_fn = lovasz_hinge

    # step 4
    optimizer = optim.SGD(net.parameters(),
                          momentum=0.9,
                          lr=LR,
                          weight_decay=1e-2)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.2)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.2, patience=5)  # args after optimizer are assumed; the original snippet is truncated here
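
    # Hedged usage sketch (not original code): ReduceLROnPlateau is stepped with
    # the monitored metric once per epoch, e.g.
    #
    #     for epoch in range(max_epoch):
    #         val_loss = validate(net, valid_loader, loss_fn)  # hypothetical helper
    #         scheduler.step(val_loss)  # LR drops by `factor` when val_loss plateaus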
Example #5
def unet_infer(demo_path_img, demo, save_result):
    """Run UNet sky segmentation on the test set or on a single demo image.

    Args:
        demo_path_img: path to the demo image (used when demo is True)
        demo: if True, run on the single image; otherwise iterate over the test set
        save_result: whether to save prediction images to disk

    Returns:
        (original_img, mask_pred_gray) in demo mode; None otherwise
    """
    # demo = True
    # demo_path_img = 'd:/MyLearning/DIP/Final_Project/Unet/Demo/1.jpg'
    # save_result = True

    testset_path = os.path.join("dataset/testset")
    checkpoint_load = 'tools/checkpoint_199_epoch.pkl'
    shuffle_dataset = True

    vis_num = 1000
    mask_thres = 0.5
    ##########################################################

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed()  # set the random seed
    in_size = 224

    if not demo:
        testset = SkyDataset(testset_path)
        valid_loader = DataLoader(testset,
                                  batch_size=1,
                                  drop_last=False,
                                  shuffle=False)
    else:
        img_pil = Image.open(demo_path_img).convert('RGB')
        original_img = np.array(img_pil)
        w, h = img_pil.size
        img_pil = img_pil.resize((in_size, in_size), Image.BILINEAR)

        img_hwc = np.array(img_pil)
        img_chw = img_hwc.transpose((2, 0, 1))
        img_chw = torch.from_numpy(img_chw).float()

    net = UNet(in_channels=3, out_channels=1,
               init_features=32)  # init_features is 64 in the standard UNet
    net.to(device)
    if checkpoint_load is not None:
        path_checkpoint = checkpoint_load
        checkpoint = torch.load(path_checkpoint)

        net.load_state_dict(checkpoint['model_state_dict'])
        print('Loaded checkpoint from %s' % path_checkpoint)
    else:
        raise Exception("\nPlease specify the checkpoint")

    net.eval()
    with torch.no_grad():
        if not demo:
            for idx, (inputs, labels) in enumerate(valid_loader):
                if idx > vis_num:
                    break
                if torch.cuda.is_available():
                    inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)

                pred = (outputs.cpu().data.numpy() * 255).astype("uint8")
                pred_gray = pred.squeeze()

                mask_pred = outputs.ge(mask_thres).cpu().data.numpy()
                mask_pred_gray = (mask_pred.squeeze() * 255).astype("uint8")

                print('idx>>%d, Dice>>%.4f' %
                      (idx, compute_dice(mask_pred,
                                         labels.cpu().numpy())))
                img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose(
                    (1, 2, 0)).astype("uint8")
                img_label = (labels.cpu().data.numpy()[0, 0, :, :] *
                             255).astype("uint8")
                plt.subplot(221).imshow(img_hwc)
                plt.title('%d Original IMG' % idx)
                plt.subplot(222).imshow(img_label, cmap="gray")
                plt.title('%d Original Label' % idx)
                plt.subplot(223).imshow(mask_pred_gray, cmap="gray")
                plt.title('%d Binary Label' % idx)
                plt.subplot(224).imshow(pred_gray, cmap="gray")
                plt.title('%d Raw Label' % idx)
                plt.tight_layout()
                plt.savefig('results/%d_img' % idx)
                plt.show()
                plt.close()
                if save_result:
                    pred_gray_img = Image.fromarray(pred_gray)
                    pred_gray_img.save('results/%d_pred_gray_img.png' % idx)

                    img_hwc_img = Image.fromarray(img_hwc)
                    img_hwc_img.save('results/%d_img_hwc.png' % idx)
        else:
            inputs = img_chw.to(device).unsqueeze(0)
            outputs = net(inputs)

            pred = (outputs.cpu().data.numpy() * 255).astype("uint8")
            pred_gray = pred.squeeze()

            mask_pred = outputs.ge(mask_thres).cpu().data.numpy()
            mask_pred_gray = (mask_pred.squeeze() * 255).astype("uint8")

            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose(
                (1, 2, 0)).astype("uint8")

            if save_result:
                pred_gray_img = Image.fromarray(pred_gray)
                pred_gray_img = pred_gray_img.resize((w, h), Image.BICUBIC)
                pred_gray_img.save(
                    'd:/MyLearning/DIP/Final_Project/Unet/results/1_pred_gray_img.png'
                )
                mask_pred_gray_img = Image.fromarray(mask_pred_gray)
                mask_pred_gray_img = mask_pred_gray_img.resize((w, h),
                                                               Image.BICUBIC)
                mask_pred_gray_img.save(
                    'd:/MyLearning/DIP/Final_Project/Unet/results/1_mask_pred_gray_img.png'
                )
                img_hwc_img = Image.open(demo_path_img).convert('RGB')
                img_hwc_img.save(
                    'd:/MyLearning/DIP/Final_Project/Unet/results/1_img_hwc_img.png'
                )
            # plt.subplot(131).imshow(img_hwc)
            # plt.subplot(132).imshow(mask_pred_gray, cmap="gray")
            # plt.subplot(133).imshow(pred_gray, cmap="gray")
            # plt.show()
            # plt.pause(0.5)
            # plt.close()

            # img_hwc = Image.fromarray(img_hwc)
            # img_hwc = img_hwc.resize((w, h), Image.BILINEAR)
            # img_hwc = np.array(img_hwc)
            mask_pred_gray = Image.fromarray(mask_pred_gray)
            mask_pred_gray = mask_pred_gray.resize((w, h), Image.BILINEAR)
            mask_pred_gray = np.array(mask_pred_gray)

            return original_img, mask_pred_gray
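
A hedged call example for unet_infer in demo mode (the image path is illustrative, not from the original snippet):

if __name__ == '__main__':
    # Returns the original RGB image and the binary sky mask at full resolution.
    original_img, mask_pred_gray = unet_infer('Demo/1.jpg', demo=True, save_result=False)
    print(original_img.shape, mask_pred_gray.shape)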
Example #6
def video_infer(img_pil):
    """Segment the sky in a single video frame.

    Args:
        img_pil: input frame as an HxWx3 RGB ndarray (despite the name,
            a numpy array is expected here, not a PIL image)

    Returns:
        (original_img, mask_pred_gray): the untouched input frame and the
        sky mask resized back to the frame's resolution
    """
    checkpoint_load = 'tools/checkpoint_199_epoch.pkl'

    vis_num = 1000
    mask_thres = 0.5
    ##########################################################

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed()  # set the random seed
    in_size = 224
    # img_pil = demo_img.convert('RGB')
    original_img = np.array(img_pil)
    h, w, _ = img_pil.shape
    # img_pil = img_pil.resize((in_size, in_size), Image.BILINEAR)
    img_pil = cv2.resize(img_pil, (in_size, in_size),
                         interpolation=cv2.INTER_AREA)

    img_hwc = np.array(img_pil)
    img_chw = img_hwc.transpose((2, 0, 1))
    img_chw = torch.from_numpy(img_chw).float()

    net = UNet(in_channels=3, out_channels=1,
               init_features=32)  # init_features is 64 in the standard UNet
    net.to(device)
    if checkpoint_load is not None:
        path_checkpoint = checkpoint_load
        checkpoint = torch.load(path_checkpoint)

        net.load_state_dict(checkpoint['model_state_dict'])
        # print('load checkpoint from %s' % path_checkpoint)
    else:
        raise Exception("\nPlease specify the checkpoint")

    net.eval()
    with torch.no_grad():

        inputs = img_chw.to(device).unsqueeze(0)
        outputs = net(inputs)

        pred = (outputs.cpu().data.numpy() * 255).astype("uint8")
        pred_gray = pred.squeeze()

        mask_pred = outputs.ge(mask_thres).cpu().data.numpy()
        mask_pred_gray = (mask_pred.squeeze() * 255).astype("uint8")

        img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose(
            (1, 2, 0)).astype("uint8")

    mask_pred_gray = Image.fromarray(mask_pred_gray)
    mask_pred_gray = mask_pred_gray.resize((w, h), Image.BILINEAR)
    mask_pred_gray = np.array(mask_pred_gray)

    return original_img, mask_pred_gray
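
# Hedged per-frame usage sketch for video_infer (illustrative; not original code).
# Despite the img_pil name, the function expects an HxWx3 RGB ndarray:
#
#     cap = cv2.VideoCapture('Demo/demo.mp4')  # illustrative path
#     success, frame_bgr = cap.read()
#     if success:
#         frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
#         original_img, mask_pred_gray = video_infer(frame_rgb)
#     cap.release()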
    # NOTE: the lines from here on are unreachable; they appear to be a truncated
    # fragment of a separate training script merged into this example.
    checkpoint_interval = 20
    vis_num = 10
    mask_thres = 0.5

    train_dir = os.path.join("D:/学习/人工智能/PyTorch学习/课程代码与作业/08-02-数据-PortraitDataset", "PortraitDataset", "train")
    valid_dir = os.path.join("D:/学习/人工智能/PyTorch学习/课程代码与作业/08-02-数据-PortraitDataset", "PortraitDataset", "valid")

    # step 1
    train_set = PortraitDataset(train_dir)
    valid_set = PortraitDataset(valid_dir)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # step 2
    net = UNet(in_channels=3, out_channels=1, init_features=32)   # init_features is 64 in the standard UNet
    net.to(device)

    # step 3
    loss_fn = nn.MSELoss()
    # step 4
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # step 5
    train_curve = list()
    valid_curve = list()
    train_dice_curve = list()
    valid_dice_curve = list()
    for epoch in range(start_epoch, max_epoch):
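        # Assumed continuation sketch; the original snippet is cut off here.
        # A generic MSE-loss training step using only the names defined above:
        net.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_curve.append(loss.item())
        scheduler.step()  # StepLR advances once per epoch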