Example #1
def main():
    args = get_FaceSR_opt()
    img_path = args.image
    img = utils.read_cv2_img(img_path)

    # build the model in inference mode and load the pretrained weights
    sr_model = SRGANModel(args, is_train=False)
    sr_model.load()

    def sr_forward(img, padding=0.5, moving=0.1):
        # align the face to 128x128 with dlib; M is the affine matrix used,
        # needed later to paste the result back into the original image
        img_aligned, M = dlib_detect_face(img,
                                          padding=padding,
                                          image_size=(128, 128),
                                          moving=moving)
        # _transform: ToTensor + Normalize(mean=0.5, std=0.5), as defined in
        # Example #4
        input_img = torch.unsqueeze(_transform(Image.fromarray(img_aligned)),
                                    0)
        sr_model.var_L = input_img.to(sr_model.device)
        sr_model.test()
        # CHW float in [-1, 1] -> HWC uint8 in [0, 255]
        output_img = sr_model.fake_H.squeeze(0).cpu().numpy()
        output_img = np.clip((np.transpose(output_img,
                                           (1, 2, 0)) / 2.0 + 0.5) * 255.0, 0,
                             255).astype(np.uint8)
        # recover the full image, scaling the alignment transform by the 4x
        # upscaling factor
        rec_img = face_recover(output_img, M * 4, img)
        return output_img, rec_img

    output_img, rec_img = sr_forward(img)
    utils.save_image(output_img, 'output_face.jpg')
    utils.save_image(rec_img, 'output_img.jpg')
Example #2
    def init_model(self):
        """
        Initialize the model.
        """
        sr_model = SRGANModel(self.get_FaceSR_opt(), is_train=False)
        sr_model.load()
        print('[Info] device: {}'.format(sr_model.device))

        return sr_model
Example #3
def SR(img):
    # preview the input (debug); waitKey(0) blocks until a key is pressed
    cv2.imshow('img', img)
    cv2.waitKey(0)

    print("Start SR test")
    try:
        sr_model = SRGANModel(get_FaceSR_opt(), is_train=False)
    except Exception as e:
        # without a model there is nothing to run, so report and re-raise
        print('Failed to build SRGANModel:', e)
        raise
    sr_model.load()

    # _transform: ToTensor + Normalize(mean=0.5, std=0.5), as in Example #4
    in_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
    sr_model.var_L = in_img.to(sr_model.device)
    sr_model.test()

    visuals = sr_model.fake_H.detach().float().cpu()
    image_numpy = utils.tensor2im(visuals, show_size=224)
    image_numpy = np.reshape(image_numpy, (-1, 224, 3))
    print('End test')
    return image_numpy
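
A minimal usage sketch (the file name is hypothetical; cv2.imread returns BGR, so convert to RGB before the array reaches PIL):

import cv2

bgr = cv2.imread('face_lr.jpg')             # hypothetical input path
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # PIL/torch pipeline expects RGB
sr = SR(rgb)                                # HWC uint8 array, per the snippet
cv2.imwrite('face_sr.jpg', cv2.cvtColor(sr, cv2.COLOR_RGB2BGR))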
Example #4
def main():

    cap = cv2.VideoCapture(0)
    # set_res is not defined in this snippet; see the sketch after this example
    print(set_res(cap, 640, 480))

    cv2.namedWindow('SR', cv2.WINDOW_NORMAL)

    # setting up the model

    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    sr_model = SRGANModel(get_FaceSR_opt(), is_train=False)
    sr_model.load()

    def sr_forward(img, padding=0.5, moving=0.1):
        # img_aligned, M = dlib_detect_face(img, padding=padding, image_size=img.shape[:-1], moving=moving)
        input_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
        sr_model.var_L = input_img.to(sr_model.device)
        sr_model.test()
        output_img = sr_model.fake_H.squeeze(0).cpu().numpy()
        output_img = np.clip((np.transpose(output_img,
                                           (1, 2, 0)) / 2.0 + 0.5) * 255.0, 0,
                             255).astype(np.uint8)
        # rec_img = face_recover(output_img, M * 4, img)
        return output_img

    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades +
                                         'haarcascade_frontalface_default.xml')

    def single_image_face_sr(image):

        # getting faces

        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        print("Faces:", len(faces))
        face_crops = []

        for x, y, w, h in faces:
            # `border` is a module-level padding ratio defined in the original
            # source but not shown in this snippet
            offsets = (int(w * border), int(h * border))  # w, h

            point = (max(0, x - offsets[0]), max(0, y - offsets[1]))

            face_crop_bgr = image[point[1]:point[1] + h + 2 * offsets[1],
                                  point[0]:point[0] + w + 2 * offsets[0]]
            face_crop = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
            print(face_crop.shape)
            face_crops.append((point, face_crop))

        new_image_bgr = cv2.resize(image,
                                   (image.shape[1] * 4, image.shape[0] * 4))
        new_image = cv2.cvtColor(new_image_bgr, cv2.COLOR_BGR2RGB)

        for face in face_crops:
            face_image = face[1]
            position = face[0] + face_image.shape[:-1]

            output_img = sr_forward(face_image)
            # output shape is (H, W, 3): rows span the height, columns the width
            new_dims = output_img.shape[:-1]
            new_image[4 * position[1]:4 * position[1] + new_dims[0],
                      4 * position[0]:4 * position[0] + new_dims[1]] = output_img

        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

        return new_image

    # main capture loop

    while True:
        ret, img = cap.read()
        if not ret:
            break
        # img = cv2.flip(img, -1)

        # The model upscales faces 4x but runs too slowly on the full 1080p
        # frame, so the frame is halved first ('cheating'); the pasted faces
        # still come out at 2x their original resolution, which keeps the loop
        # close to real time.
        img = cv2.resize(img, (int(img.shape[1] / 2), int(img.shape[0] / 2)))
        print(img.shape)

        img_sr = single_image_face_sr(img)

        cv2.imshow('SR', img_sr)

        k = cv2.waitKey(30) & 0xff
        if k == 27:  # press 'ESC' to quit
            break

    cap.release()
    cv2.destroyAllWindows()
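
set_res is called above but not defined in the snippet; a minimal sketch consistent with its usage (request a capture resolution, return what the driver actually reports):

def set_res(cap, width, height):
    # drivers may pick the nearest supported mode, so report back the
    # effective values rather than assuming the request succeeded
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    return cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)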
Example #5
    # network D (this snippet begins partway through the option-parser
    # function, presumably get_FaceSR_opt; earlier options are omitted)
    parser.add_argument('--which_model_D', type=str, default='discriminator_vgg_128')
    parser.add_argument('--D_in_nc', type=int, default=3)
    parser.add_argument('--D_nf', type=int, default=64)

    # data dir
    parser.add_argument('--pretrain_model_G', type=str, default='90000_G.pth')
    parser.add_argument('--pretrain_model_D', type=str, default=None)

    args = parser.parse_args()

    return args


sr_model = SRGANModel(get_FaceSR_opt(), is_train=False)
sr_model.load()

def sr_forward(img, padding=0.5, moving=0.1):
    img_aligned, M = dlib_detect_face(img, padding=padding, image_size=(128, 128), moving=moving)
    input_img = torch.unsqueeze(_transform(Image.fromarray(img_aligned)), 0)
    sr_model.var_L = input_img.to(sr_model.device)
    sr_model.test()
    output_img = sr_model.fake_H.squeeze(0).cpu().numpy()
    output_img = np.clip((np.transpose(output_img, (1, 2, 0)) / 2.0 + 0.5) * 255.0, 0, 255).astype(np.uint8)
    rec_img = face_recover(output_img, M * 4, img)
    return output_img, rec_img

img_path = 'input.jpg'
img = utils.read_cv2_img(img_path)
output_img, rec_img = sr_forward(img)
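
The clip/transpose line above inverts the Normalize(mean=0.5, std=0.5) preprocessing: the network output lies roughly in [-1, 1], and x / 2 + 0.5 maps it back to [0, 1] before scaling to 8 bits. As a reusable sketch (the helper name is ours, not from the source):

import numpy as np

def tensor_to_uint8(t):
    # 1xCxHxW or CxHxW float tensor in [-1, 1] -> HxWxC uint8 in [0, 255]
    arr = t.squeeze(0).detach().cpu().numpy()
    arr = np.transpose(arr, (1, 2, 0)) / 2.0 + 0.5
    return np.clip(arr * 255.0, 0.0, 255.0).astype(np.uint8)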
Example #6
def single_image_face_sr(img_path, image=None):

    # getting faces:
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades +
                                         'haarcascade_frontalface_default.xml')

    if img_path is not None:
        image = cv2.imread(img_path)
        print(img_path)

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.1, 4)
    print("Faces:", len(faces))
    face_crops = []

    for x, y, w, h in faces:
        # `border` is a module-level padding ratio defined in the original
        # source but not shown in this snippet
        offsets = (int(w * border), int(h * border))  # w, h

        point = (max(0, x - offsets[0]), max(0, y - offsets[1]))

        face_crop_bgr = image[point[1]:point[1] + h + 2 * offsets[1],
                              point[0]:point[0] + w + 2 * offsets[0]]
        face_crop = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
        print(face_crop.shape)
        face_crops.append((point, face_crop))
    # doing super res
    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    sr_model = SRGANModel(get_FaceSR_opt(), is_train=False)
    sr_model.load()

    def sr_forward(img, padding=0.5, moving=0.1):
        # img_aligned, M = dlib_detect_face(img, padding=padding, image_size=img.shape[:-1], moving=moving)
        input_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
        sr_model.var_L = input_img.to(sr_model.device)
        sr_model.test()
        output_img = sr_model.fake_H.squeeze(0).cpu().numpy()
        output_img = np.clip((np.transpose(output_img,
                                           (1, 2, 0)) / 2.0 + 0.5) * 255.0, 0,
                             255).astype(np.uint8)
        # rec_img = face_recover(output_img, M * 4, img)
        return output_img

    # setting up the resized image

    new_image_bgr = cv2.resize(image, (image.shape[1] * 4, image.shape[0] * 4))
    new_image = cv2.cvtColor(new_image_bgr, cv2.COLOR_BGR2RGB)

    for face in face_crops:
        face_image = face[1]
        position = face[0] + face_image.shape[:-1]

        output_img = sr_forward(face_image)
        # output shape is (H, W, 3): rows span the height, columns the width
        new_dims = output_img.shape[:-1]
        new_image[4 * position[1]:4 * position[1] + new_dims[0],
                  4 * position[0]:4 * position[0] + new_dims[1]] = output_img

    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
    if img_path is not None:
        output_path = img_path[:-4] + "_face_sr.png"
        cv2.imwrite(output_path, new_image)
        print("done!")
    return new_image
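
Usage is then a single call (the file name is hypothetical); passing None as the path runs the same pipeline on an in-memory frame instead:

# reads photo.jpg, upscales detected faces 4x, writes photo_face_sr.png
sr_image = single_image_face_sr('photo.jpg')

# or skip the file round-trip; `frame` stands for any BGR image in memory
sr_image = single_image_face_sr(None, image=frame)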
Example #7
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_ids', type=str, default='0,1,2,3,4,5,6,7')

    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--dev_ratio', type=float, default=0.01)
    parser.add_argument('--lr_G', type=float, default=1e-4)
    parser.add_argument('--weight_decay_G', type=float, default=0)
    parser.add_argument('--beta1_G', type=float, default=0.9)
    parser.add_argument('--beta2_G', type=float, default=0.99)
    parser.add_argument('--lr_D', type=float, default=1e-4)
    parser.add_argument('--weight_decay_D', type=float, default=0)
    parser.add_argument('--beta1_D', type=float, default=0.9)
    parser.add_argument('--beta2_D', type=float, default=0.99)
    parser.add_argument('--lr_scheme', type=str, default='MultiStepLR')
    parser.add_argument('--niter', type=int, default=100000)
    parser.add_argument('--warmup_iter', type=int, default=-1)
    parser.add_argument('--lr_steps', nargs='+', type=int, default=[50000])
    parser.add_argument('--lr_gamma', type=float, default=0.5)
    parser.add_argument('--pixel_criterion', type=str, default='l1')
    parser.add_argument('--pixel_weight', type=float, default=1e-2)
    parser.add_argument('--feature_criterion', type=str, default='l1')
    parser.add_argument('--feature_weight', type=float, default=1)
    parser.add_argument('--gan_type', type=str, default='ragan')
    parser.add_argument('--gan_weight', type=float, default=5e-3)
    parser.add_argument('--D_update_ratio', type=int, default=1)
    parser.add_argument('--D_init_iters', type=int, default=0)

    parser.add_argument('--print_freq', type=int, default=100)
    parser.add_argument('--val_freq', type=int, default=1000)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--crop_size', type=float, default=0.85)
    parser.add_argument('--lr_size', type=int, default=128)
    parser.add_argument('--hr_size', type=int, default=512)

    # network G
    parser.add_argument('--which_model_G', type=str, default='RRDBNet')
    parser.add_argument('--G_in_nc', type=int, default=3)
    parser.add_argument('--out_nc', type=int, default=3)
    parser.add_argument('--G_nf', type=int, default=64)
    parser.add_argument('--nb', type=int, default=16)

    # network D
    parser.add_argument('--which_model_D',
                        type=str,
                        default='discriminator_vgg_128')
    parser.add_argument('--D_in_nc', type=int, default=3)
    parser.add_argument('--D_nf', type=int, default=32)

    # data dir
    parser.add_argument('--hr_path',
                        nargs='+',
                        type=str,
                        default=['data/celebahq-512/', 'data/ffhq-512/'])
    parser.add_argument('--lr_path', type=str, default='data/lr-128/')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='check_points/ESRGAN-V1/')
    parser.add_argument('--val_dir', type=str, default='dev_show')
    parser.add_argument('--training_state',
                        type=str,
                        default='check_points/ESRGAN-V1/state/')

    # resume the training
    parser.add_argument('--resume_state', type=str, default=None)
    parser.add_argument('--pretrain_model_G', type=str, default=None)
    parser.add_argument('--pretrain_model_D', type=str, default=None)

    parser.add_argument('--setting_file', type=str, default='setting.txt')
    parser.add_argument('--log_file', type=str, default='log.txt')
    args = check_args(parser.parse_args())
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids

    #### loading resume state if exists
    if args.resume_state is not None:
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            args.resume_state,
            map_location=lambda storage, loc: storage.cuda(device_id))
    else:
        resume_state = None

    # load dataset
    total_img_list = []
    for hr_path in args.hr_path:
        total_img_list.extend(glob(hr_path + '/*'))

    random.shuffle(total_img_list)
    dev_list = total_img_list[:int(len(total_img_list) * args.dev_ratio)]
    train_list = total_img_list[int(len(total_img_list) * args.dev_ratio):]

    train_loader = create_dataloader(args,
                                     train_list,
                                     is_train=True,
                                     n_threads=len(args.gpu_ids.split(',')))
    dev_loader = create_dataloader(args,
                                   dev_list,
                                   is_train=False,
                                   n_threads=len(args.gpu_ids.split(',')))

    #### create model
    model = SRGANModel(args, is_train=True)

    #### resume training
    if resume_state is not None:
        model.load()
        print('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    total_epochs = int(math.ceil(args.niter / len(train_loader)))

    #### training
    print('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > args.niter:
                # exits only the dataloader loop; the epoch loop finishes on
                # its own once total_epochs is reached
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=args.warmup_iter)

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % args.print_freq == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                print(message)

            # validation
            if current_step % args.val_freq == 0:
                show_dir = os.path.join(args.checkpoint_dir, 'show_dir')
                os.makedirs(show_dir, exist_ok=True)
                # grab a single dev batch for visualization
                dev_data = next(iter(dev_loader))

                model.feed_data(dev_data)
                model.test()

                visuals = model.get_current_visuals()
                display_online_results(visuals,
                                       current_step,
                                       show_dir,
                                       show_size=args.hr_size)

            #### save models and training states
            if current_step % args.save_freq == 0:
                print('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
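
The snippet defines main() but never invokes it; presumably the original file ends with the standard entry-point guard:

if __name__ == '__main__':
    main()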