Example #1
 def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
     input_image, reference_image = batch
     forged_image = self.generator(input_image)
     # tensor2im conventionally yields an H x W x C uint8 NumPy array
     # (see the sketch after this example), so detach before converting
     # and pass the channel layout to TensorBoard explicitly.
     forged_image = tensor2im(forged_image.detach())
     input_image = tensor2im(input_image)
     reference_image = tensor2im(reference_image)
     if batch_idx % 50 == 0:
         tensorboard = self.logger.experiment
         tensorboard.add_image("Forged", forged_image, dataformats="HWC")
         tensorboard.add_image("Input", input_image, dataformats="HWC")
         tensorboard.add_image("Reference", reference_image, dataformats="HWC")
     self.forged_images[batch_idx] = forged_image
     self.reference_images[batch_idx] = reference_image
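Every example on this page calls a tensor2im helper without showing it. A minimal sketch of the CycleGAN-style convention these projects follow (individual projects add extra options such as show_size or size; this version is an assumption, not any one project's exact code):

import numpy as np
import torch

def tensor2im(input_image, imtype=np.uint8):
    """Map a [-1, 1] image tensor (N x C x H x W or C x H x W) to an
    H x W x C uint8 array (assumed CycleGAN-style convention)."""
    if not isinstance(input_image, torch.Tensor):
        return input_image.astype(imtype)  # already a NumPy array
    image_numpy = input_image.detach().cpu().float().numpy()
    if image_numpy.ndim == 4:        # drop the batch dimension
        image_numpy = image_numpy[0]
    if image_numpy.shape[0] == 1:    # grayscale -> 3-channel
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # CHW -> HWC, then rescale [-1, 1] -> [0, 255]
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)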
Example #2
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see htmlu.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
        image_path (str)         -- the string is used to create image paths
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = utils.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        utils.save_image(im, save_path, aspect_ratio=aspect_ratio)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)
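A hypothetical call sketch for the function above. The htmlu.HTML constructor signature is taken from Example #10 below; the tensors and the image path are placeholders:

import torch
from collections import OrderedDict

# Hypothetical usage (placeholder path and random tensors for illustration)
webpage = htmlu.HTML('./results/web', 'latest test results')
visuals = OrderedDict([('real_A', torch.randn(1, 3, 256, 256)),
                       ('fake_B', torch.randn(1, 3, 256, 256))])
save_images(webpage, visuals, ['./testA/example.jpg'], aspect_ratio=1.0, width=256)
webpage.save()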
Example #3
def main():
    # NOTE: excerpt -- `i`, `j`, `result`, `sr_model`, `crop_img`, `f_img`,
    # `my_count`, `r_c`, `args`, and `start_time` come from the enclosing
    # script, which is elided here.
    # ---------------------------Test ---------------------------------
    print("Start SR test")
    img = utils.read_cv2_img(i)
    in_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
    sr_model.var_L = in_img.to(sr_model.device)
    sr_model.test()
    visuals = sr_model.fake_H.detach().float().cpu()
    image_numpy = utils.tensor2im(visuals, show_size=317)
    # tensor2im returns a flat buffer here; restore H x W x C layout
    image_numpy = np.reshape(image_numpy, (-1, 317, 3))
    # cv2.resize expects (width, height); img.shape is (height, width, ...)
    image_numpy = cv2.resize(image_numpy, (img.shape[1], img.shape[0]))
    print('End test')
    print()
    # ----------------------------------End test--------------------------

    # -----------------------------SR img combine Original img --------------------------------------------
    # Face-box coordinates in the original frame
    start_x = result[i][j]['box'][0]
    start_y = result[i][j]['box'][1]
    end_x = result[i][j]['box'][0] + result[i][j]['box'][2]
    end_y = result[i][j]['box'][1] + result[i][j]['box'][3]
    area = (start_x, start_y, end_x, end_y)
    crop_img = crop_img.resize((int(end_x - start_x), int(end_y - start_y)))

    px = f_img.load()
    c_px = crop_img.load()
    c_x = 0
    c_x_max = crop_img.width
    c_y = 0
    c_y_max = crop_img.height
    # Paste the super-resolved crop back into the full frame, pixel by pixel
    for q in range(start_x, end_x):
        c_y = 0
        for k in range(start_y, end_y):
            try:
                px[q, k] = c_px[c_x, c_y]
            except IndexError:
                pass  # the resized crop can fall a pixel short of the box
            if c_y < c_y_max:
                c_y += 1
        if c_x < c_x_max:
            c_x += 1
    my_count = my_count + 1
    f_img.save(args.final_path +
               '/{}.png'.format(i.split('/')[-1].split('.')[0]))
    r_c = r_c + 1
    print("End SR img combine to ori img")
    print()
    # -----------------------------SR img combine Original img End --------------------------------------------
    end_time = time.time()

    print("time = {}".format(end_time - start_time))
Example #4
def main():

    # Create Result Directory
    os.makedirs('./results/predict', exist_ok=True)

    # Get Arguments
    args = args_initialize()

    # Define Model
    net_G = ResNetGenerator(
        input_nc=args.input_nc,
        output_nc=args.output_nc,
        ngf=args.ngf,
        n_blocks=9
    )

    # Load Weights
    state_dict = torch.load('./latest_net_G.pth', map_location='cpu')
    net_G.load_state_dict(state_dict)
    net_G.eval()

    # Create Tensor from Image file
    im_file = args.imfile
    tensor_img = utils.create_data(im_file)

    # Predict (no gradients needed at inference time)
    with torch.no_grad():
        outputs = net_G(tensor_img)[0]

    # Convert Output Tensor to Image file
    im = utils.tensor2im(outputs)
    file_name = os.path.basename(im_file)
    save_path = os.path.join('./results/predict', 'horse2zebra_' + str(file_name) + '.png')
    utils.save_image(im, save_path)
Example #5
def SR(img):
    cv2.imshow('img', img)
    cv2.waitKey(0)
    print("Start SR test")
    try:
        sr_model = SRGANModel(get_FaceSR_opt(), is_train=False)
    except Exception as e:
        print('no module', e)
        raise  # sr_model is unusable past this point
    sr_model.load()

    in_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
    sr_model.var_L = in_img.to(sr_model.device)
    sr_model.test()
    visuals = sr_model.fake_H.detach().float().cpu()
    image_numpy = utils.tensor2im(visuals, show_size=224)
    # tensor2im returns a flat buffer here; restore H x W x C layout
    image_numpy = np.reshape(image_numpy, (-1, 224, 3))
    print('End test')
    return image_numpy
Example #6
def translate_patch(hybrid, cmask, patch_id, loc, model):
    # `opt` comes from the enclosing module's scope
    transform = get_transform(opt)
    hybrid_img = utils.ndarrayToPilImage(hybrid)
    hybrid_img = transform(hybrid_img)
    model_input = {'unstyled': hybrid_img, 'hybrid': hybrid_img, 'mask': cmask}
    model.set_input(model_input)
    model.test()
    result = model.fake
    im = utils.tensor2im(result)
    cres = np.asarray(im)

    return cres
Example #7
    def get_current_visuals(self):
        real_A = tensor2im(self.input_A)
        fake_B = tensor2im(self.fake_B)
        rec_A = tensor2im(self.rec_A)
        real_B = tensor2im(self.input_B)
        fake_A = tensor2im(self.fake_A)
        rec_B = tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.isTrain and self.p.identity > 0.0:
            ret_visuals['idt_A'] = tensor2im(self.idt_A)
            ret_visuals['idt_B'] = tensor2im(self.idt_B)

        return ret_visuals
Example #8
 def inference(self, gpu_id, dataloader, save_dir, latent_size,
               num_lighting_infer, label, visualizer):
     self.to(gpu_id)
     self.load(save_dir, label, visualizer)
     self.rand_G.eval()
     tqdm_data_loader = tqdm(dataloader, desc='infer', leave=False)
     rand_img_dir = os.path.join(save_dir, 'infer_rand')
     os.makedirs(rand_img_dir, exist_ok=True)
     for i, inputs in enumerate(tqdm_data_loader):
         for j in range(num_lighting_infer):
             studio_img = inputs['base'].to(self.device)
             light_vec = torch.randn(
                 (studio_img.shape[0], *latent_size)).to(self.device)
             fake_rand_img = self.rand_G(studio_img, light_vec)
             fake_rand_img = tensor2im(fake_rand_img)
             for k in range(studio_img.shape[0]):
                 fake_k_lighting_j = fake_rand_img[k, :, :]
                 save_folder = os.path.join(rand_img_dir, str(k + 1))
                 os.makedirs(save_folder, exist_ok=True)
                 file_path = os.path.join(save_folder, f'{j + 1}.jpg')
                 io.imsave(file_path, fake_k_lighting_j)
     self.rand_G.train()
Example #9
File: test.py Project: OkayMing/BPGAN
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = utils.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        if aspect_ratio > 1.0:
            im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
        if aspect_ratio < 1.0:
            im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
        utils.save_image(im, save_path)

        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)
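Note that scipy.misc.imresize, used above, was deprecated in SciPy 1.0 and removed in 1.3. A rough Pillow-based stand-in (an approximation; SciPy's version also handled mode-dependent rescaling):

import numpy as np
from PIL import Image

def imresize(arr, size, interp='bicubic'):
    # `size` is (height, width), matching the calls above;
    # PIL's resize wants (width, height)
    resample = {'nearest': Image.NEAREST,
                'bilinear': Image.BILINEAR,
                'bicubic': Image.BICUBIC}[interp]
    h, w = size
    return np.array(Image.fromarray(arr).resize((w, h), resample=resample))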
Example #10
    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) - - dictionary of images to display or save
            epoch (int) - - the current epoch
            save_result (bool) - - if save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:  # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = utils.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images,
                                    nrow=ncols,
                                    win=self.display_id + 1,
                                    padding=2,
                                    opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html,
                                  win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel;
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = utils.tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]),
                                       opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (
                save_result or not self.saved
        ) and epoch % 1000 == 0:  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = utils.tensor2im(image)
                img_path = os.path.join(self.img_dir,
                                        'epoch%.3d_%s.png' % (epoch, label))
                utils.save_image(image_numpy, img_path)

            # update website
            webpage = htmlu.HTML(self.web_dir,
                                 'Experiment name = %s' % self.name,
                                 refresh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image_numpy in visuals.items():
                    image_numpy = utils.tensor2im(image_numpy)
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()
Example #11
def visualize_sample(model, batch, vocab):
    (imgs, canvases_sel, canvases_ori, objs, boxes, selected_crops,
     original_crops, triples, predicates, obj_to_img, triple_to_img,
     scatter_size_obj, scatter_size_triple) = batch

    samples = []
    # add the ground-truth images
    samples.append(imgs[:1])

    # add the canvases building with original crops
    if canvases_ori is not None:
        samples.append(canvases_ori[:1])

    with torch.no_grad():
        model_out = model(objs,
                          triples,
                          obj_to_img,
                          triple_to_img,
                          boxes_gt=boxes,
                          selected_crops=selected_crops,
                          original_crops=original_crops,
                          scatter_size_obj=scatter_size_obj,
                          scatter_size_triple=scatter_size_triple)

        # add the reconstructed images
        samples.append(model_out[1][:1])

        # add the canvases building with selected crops
        if canvases_sel is not None:
            samples.append(canvases_sel[:1])

        # add the generated images
        samples.append(model_out[0][:1])

        model_out = model(objs,
                          triples,
                          obj_to_img,
                          triple_to_img,
                          boxes_gt=boxes,
                          selected_crops=selected_crops,
                          original_crops=original_crops,
                          scatter_size_obj=scatter_size_obj,
                          scatter_size_triple=scatter_size_triple)
        # add the generated images
        samples.append(model_out[0][:1])

        model_out = model(objs,
                          triples,
                          obj_to_img,
                          triple_to_img,
                          selected_crops=selected_crops,
                          original_crops=original_crops,
                          scatter_size_obj=scatter_size_obj,
                          scatter_size_triple=scatter_size_triple)
        # add the generated images
        samples.append(model_out[0][:1])
    samples = torch.cat(samples, dim=3)
    samples = {
        "samples":
        tensor2im(imagenet_deprocess_batch(samples, rescale=True).squeeze(0))
    }
    # Draw Scene Graphs
    sg_array = draw_scene_graph(objs[obj_to_img == 0],
                                triples[triple_to_img == 0],
                                vocab=vocab)
    samples["scene_graph"] = sg_array
    return samples
Example #12
    def display_current_results(self, visuals, epoch):
        """Display current results on visdom.

        Parameters:
            visuals (OrderedDict) - - dictionary of images to display
            epoch (int) - - the current epoch
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:  # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images,
                                    nrow=ncols,
                                    win=self.display_id + 1,
                                    padding=2,
                                    opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html,
                                  win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel;
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]),
                                       opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()
Example #13
 def get_current_visuals(self):
     real_A = tensor2im(self.real_A.data)
     fake_B = tensor2im(self.fake_B.data)
     return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
Example #14
 def computeTranslateImage(self, src):
     real_image = self.transform_func(src).unsqueeze(0)
     real_image = real_image.to(self.device)
     with torch.no_grad():
         fake_image, _, _ = self.generator(real_image)
     self.fake_image = tensor2im(fake_image)
Example #15
    def test(self, data_fetcher, num_samples, if_baseline=False, if_return_each=False, img_save_folder=None,
             if_train=True):
        """
        val (in training): idx_out=0/1/2/3/4
        test: idx_out=-2, record time wo. iqa
        """
        if if_baseline or if_train:
            assert self.crit_lst is not None, 'NO METRICS!'

        if self.crit_lst is not None:
            if_tar_only = False
            msg = 'dst vs. src | ' if if_baseline else 'tar vs. src | '
        else:
            if_tar_only = True
            msg = 'only get dst | '

        report_dict = None

        recorder_dict = dict()
        if self.crit_lst is not None:  # guard: crit_lst is None in tar-only mode
            for crit_name in self.crit_lst:
                recorder_dict[crit_name] = Recorder()

        write_dict_lst = []
        timer = CUDATimer()

        # validation baseline:      no iqa, no parse name
        # validation, not baseline: no iqa, parse name
        # test baseline:            no iqa, no parse name
        # test, not baseline:       iqa, no parse name
        if_iqa = (not if_train) and (not if_baseline)
        if if_iqa:
            timer_wo_iqam = Recorder()
            idx_out = -2  # testing; judge by IQAM
        if_parse_name = if_train and (not if_baseline)

        self.set_eval_mode()

        data_fetcher.reset()
        test_data = data_fetcher.next()
        assert len(test_data['name']) == 1, 'ONLY SUPPORT bs==1!'

        pbar = tqdm(total=num_samples, ncols=100)

        while test_data is not None:
            im_lq = test_data['lq'].cuda(non_blocking=True)  # assume bs=1
            im_name = test_data['name'][0]  # assume bs=1

            if if_parse_name:
                im_type = im_name.split('_')[-1].split('.')[0]
                if im_type in ['qf50', 'qp22']:
                    idx_out = 0
                elif im_type in ['qf40', 'qp27']:
                    idx_out = 1
                elif im_type in ['qf30', 'qp32']:
                    idx_out = 2
                elif im_type in ['qf20', 'qp37']:
                    idx_out = 3
                elif im_type in ['qf10', 'qp42']:
                    idx_out = 4
                else:
                    raise Exception(f"im_type IS {im_type}, NO MATCHING TYPE!")

            timer.start_record()
            if if_tar_only:
                if if_iqa:
                    # the IQA path returns (time, image); clamp only the image
                    time_wo_iqa, im_out = self.model.net[self.model.infer_subnet](inp_t=im_lq, idx_out=idx_out)
                    im_out = im_out.clamp_(0., 1.)
                else:
                    im_out = self.model.net[self.model.infer_subnet](inp_t=im_lq, idx_out=idx_out).clamp_(0., 1.)
                timer.record_inter()
            else:
                im_gt = test_data['gt'].cuda(non_blocking=True)  # assume bs=1
                if if_baseline:
                    im_out = im_lq
                else:
                    if if_iqa:
                        time_wo_iqa, im_out = self.model.net[self.model.infer_subnet](inp_t=im_lq, idx_out=idx_out)
                        im_out = im_out.clamp_(0., 1.)
                    else:
                        im_out = self.model.net[self.model.infer_subnet](inp_t=im_lq, idx_out=idx_out).clamp_(0., 1.)
                timer.record_inter()

                _msg = f'{im_name} | '

                for crit_name in self.crit_lst:
                    crit_fn = self.crit_lst[crit_name]['fn']
                    crit_unit = self.crit_lst[crit_name]['unit']

                    perfm = crit_fn(torch.squeeze(im_out, 0), torch.squeeze(im_gt, 0))
                    recorder_dict[crit_name].record(perfm)

                    _msg += f'[{perfm:.3e}] {crit_unit:s} | '

                _msg = _msg[:-3]
                if if_return_each:
                    msg += _msg + '\n'
                pbar.set_description(_msg)

            if if_iqa:
                timer_wo_iqam.record(time_wo_iqa)

            if img_save_folder is not None:  # save im
                im = tensor2im(torch.squeeze(im_out, 0))
                save_path = img_save_folder / (str(im_name) + '.png')
                cv2.imwrite(str(save_path), im)

            pbar.update()
            test_data = data_fetcher.next()
        pbar.close()

        if not if_tar_only:
            for crit_name in self.crit_lst:
                crit_unit = self.crit_lst[crit_name]['unit']
                crit_if_focus = self.crit_lst[crit_name]['if_focus']

                ave_perfm = recorder_dict[crit_name].get_ave()
                msg += f'{crit_name} | [{ave_perfm:.3e}] {crit_unit} | '

                write_dict_lst.append(dict(tag=f'{crit_name} (val)', scalar=ave_perfm))

                if crit_if_focus:
                    report_dict = dict(ave_perfm=ave_perfm, lsb=self.crit_lst[crit_name]['fn'].lsb)

        ave_fps = 1. / timer.get_ave_inter()
        msg += f'ave. fps | [{ave_fps:.1f}]'

        if if_iqa:
            ave_time_wo_iqam = timer_wo_iqam.get_ave()
            fps_wo_iqam = 1. / ave_time_wo_iqam
            msg += f' | ave. fps wo. IQAM | [{fps_wo_iqam:.1f}]'

        if if_train:
            assert report_dict is not None
            return msg.rstrip(), write_dict_lst, report_dict
        else:
            return msg.rstrip()
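A design note on the if_parse_name branch above: the qf/qp chain maps pairs of compression settings to the same output index, so a lookup table expresses it more compactly. A sketch with the same pairs and the same failure behavior:

IDX_OUT_BY_TYPE = {
    'qf50': 0, 'qp22': 0,
    'qf40': 1, 'qp27': 1,
    'qf30': 2, 'qp32': 2,
    'qf20': 3, 'qp37': 3,
    'qf10': 4, 'qp42': 4,
}

def parse_idx_out(im_name):
    # same name parsing as the loop body above
    im_type = im_name.split('_')[-1].split('.')[0]
    if im_type not in IDX_OUT_BY_TYPE:
        raise Exception(f"im_type IS {im_type}, NO MATCHING TYPE!")
    return IDX_OUT_BY_TYPE[im_type]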
Example #16
def main():
    opt = Options().parse()
    img_names = []
    for name in os.listdir(opt.input):
        if any(
                name.endswith(extension) for extension in [
                    '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                    '.PPM', '.bmp', '.BMP', '.tiff'
                ]):
            img_names.append(name)

    bilinear_model = BilinearModel(opt.predef_dir)

    if opt.render:
        from renderer import MeshRenderer
        renderer = MeshRenderer()

    if opt.name == 'dpmap_rig':
        opt.input_nc = 6
        pos_maps = np.load(f'{opt.predef_dir}/posmaps.npz')
        pos_maps = pos_maps.f.arr_0
        pos_maps = torch.from_numpy(pos_maps).unsqueeze(0)

    dpmap_model = create_model(opt)

    for img_name in img_names:
        print(f'\nProcessing {img_name}')
        base_name = os.path.splitext(img_name)[0]
        if not os.path.exists(f'{opt.output}/{base_name}'):
            os.mkdir(f'{opt.output}/{base_name}')

        img = cv2.imread(f'{opt.input}/{img_name}')

        print('Fitting 3DMM Parameters...')
        proj_params, verts = bilinear_model.fit_image(img)

        print('Warping texture...')
        verts_img = bilinear_model.project(verts, *proj_params, keepz=False)
        texture = bilinear_model.get_texture(img, verts_img)
        bilinear_model.save_obj(f'{opt.output}/{base_name}/{base_name}.obj',
                                verts,
                                f'./{base_name}.jpg',
                                front=True)
        cv2.imwrite(f'{opt.output}/{base_name}/{base_name}.jpg', texture)

        texture = cv2.resize(texture[600:2500, 1100:3000],
                             (1024, 1024)).astype(np.uint8)

        mask = (255 - cv2.imread(f'{opt.predef_dir}/front_mask.png')[:, :, 0]).astype(bool)
        new_pixels = color_transfer(texture[mask][:, np.newaxis, :])
        texture[mask] = new_pixels[:, 0, :]
        texture = cv2.cvtColor(texture, cv2.COLOR_BGR2RGB).astype(np.float32)
        texture = np.transpose(texture, (2, 0, 1))
        texture = torch.tensor(texture) / 255
        texture = F.normalize(texture, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), True)
        texture = torch.unsqueeze(texture, 0)

        print('Generating displacement maps...')
        dpmap_full = np.zeros((4096, 4096), dtype=np.uint16)
        dpmap_full[...] = 32768
        dpmap_full = Image.fromarray(dpmap_full)
        if opt.name == 'dpmap_rig':
            for i in range(20):
                ipt = torch.cat((texture, pos_maps[:, i * 3:i * 3 + 3]), dim=1)
                dpmap = dpmap_model.inference(ipt, torch.tensor(0))
                dpmap = tensor2im(dpmap.detach()[0], size=(1900, 1900))
                dpmap = Image.fromarray(dpmap)
                dpmap_full.paste(dpmap, (1100, 600, 3000, 2500))
                dpmap_full.save(
                    f'{opt.output}/{base_name}/{base_name}_dpmap_{str(i)}.png')
        else:
            dpmap = dpmap_model.inference(texture, torch.tensor(0))
            dpmap = tensor2im(dpmap.detach()[0], size=(1900, 1900))
            dpmap = Image.fromarray(dpmap)
            dpmap_full.paste(dpmap, (1100, 600, 3000, 2500))
            dpmap_full.save(f'{opt.output}/{base_name}/{base_name}_dpmap.png')
            if opt.render:
                print('Rendering results...')
                front_verts = verts[bilinear_model.front_verts_indices]
                tris = bilinear_model.tris.copy()
                vert_texcoords = bilinear_model.vert_texcoords.copy()
                for _ in range(3):
                    front_verts, tris, vert_texcoords = subdiv(
                        front_verts, tris, vert_texcoords)
                front_verts = dpmap2verts(front_verts, tris, vert_texcoords,
                                          dpmap_full)

                verts_img = bilinear_model.project(front_verts,
                                                   *proj_params,
                                                   keepz=True)
                renderer.render(
                    verts_img, tris, (img.shape[1], img.shape[0]),
                    f'{opt.input}/{img_name}',
                    f'{opt.output}/{base_name}/{base_name}_render.jpg')
Example #17
    def evaluate(self,
                 dataloader,
                 save_dir,
                 phase='test',
                 save_result=False,
                 eval_step=-1):
        self.rand_G.eval()
        self.studio_G.eval()
        psnr_studio = 0
        ssim_studio = 0
        psnr_rand = 0
        ssim_rand = 0
        tqdm_data_loader = tqdm(dataloader, desc=phase, leave=False)
        idx = 0
        if save_result:
            studio_img_dir = os.path.join(save_dir, f'{phase}_studio')
            rand_img_dir = os.path.join(save_dir, f'{phase}_rand')
            os.makedirs(studio_img_dir, exist_ok=True)
            os.makedirs(rand_img_dir, exist_ok=True)
        for i, inputs in enumerate(tqdm_data_loader):
            rand_img = inputs['rand_lc'].to(self.device)
            studio_img = inputs['base'].to(self.device)

            fake_studio_img, light_vec_forward = self.studio_G(rand_img)

            fake_rand_img = self.rand_G(studio_img, light_vec_forward)
            crop_size = 10
            fake_studio = tensor2im(fake_studio_img)
            fake_rand = tensor2im(fake_rand_img)
            rand = tensor2im(rand_img)
            studio = tensor2im(studio_img)
            fake_studio = fake_studio[:, crop_size:-crop_size,
                                      crop_size:-crop_size]
            fake_rand = fake_rand[:, crop_size:-crop_size,
                                  crop_size:-crop_size]
            gt_studio = studio[:, crop_size:-crop_size, crop_size:-crop_size]
            gt_rand = rand[:, crop_size:-crop_size, crop_size:-crop_size]
            for j in range(rand_img.shape[0]):
                gt_rand_j = gt_rand[j, :, :]
                gt_studio_j = gt_studio[j, :, :]
                fake_rand_j = fake_rand[j, :, :]
                fake_studio_j = fake_studio[j, :, :]
                if save_result:

                    def save_pair(path, gt, fake):
                        # renamed from `save_result` to avoid shadowing the
                        # method's `save_result` flag
                        gt_dir = os.path.join(path, 'gt')
                        fake_dir = os.path.join(path, 'fake')
                        os.makedirs(gt_dir, exist_ok=True)
                        os.makedirs(fake_dir, exist_ok=True)
                        gt_file = os.path.join(gt_dir, f'{idx + 1}.jpg')
                        io.imsave(gt_file, gt)
                        fake_file = os.path.join(fake_dir, f'{idx + 1}.jpg')
                        io.imsave(fake_file, fake)

                    save_pair(studio_img_dir, gt_studio_j, fake_studio_j)
                    save_pair(rand_img_dir, gt_rand_j, fake_rand_j)
                psnr_studio += calculate_psnr(gt_studio_j, fake_studio_j)
                psnr_rand += calculate_psnr(gt_rand_j, fake_rand_j)
                ssim_studio += structural_similarity(gt_studio_j,
                                                     fake_studio_j,
                                                     data_range=255,
                                                     multichannel=False,
                                                     gaussian_weights=True,
                                                     K1=0.01,
                                                     K2=0.03)
                ssim_rand += structural_similarity(gt_rand_j,
                                                   fake_rand_j,
                                                   data_range=255,
                                                   multichannel=False,
                                                   gaussian_weights=True,
                                                   K1=0.01,
                                                   K2=0.03)
                idx += 1
            if eval_step != -1 and (i + 1) % eval_step == 0:
                break

        self.rand_G.train()
        self.studio_G.train()

        return {
            'psnr_rand': psnr_rand / idx,
            'ssim_rand': ssim_rand / idx,
            'psnr_studio': psnr_studio / idx,
            'ssim_studio': ssim_studio / idx
        }
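Example #17 accumulates calculate_psnr over uint8 crops with a 255 data range. The project's own implementation is not shown; a minimal stand-in consistent with that usage:

import numpy as np

def calculate_psnr(img1, img2):
    # assumes uint8 inputs of equal shape, data range 255
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))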
Example #18
    dim = [256, 256]
    image = image.resize(dim, Image.BICUBIC)
    mask = mask.resize(dim, Image.BICUBIC)
    target = target.resize(dim, Image.BICUBIC)
    image = transforms.ToTensor()(image)
    mask = transforms.ToTensor()(mask)
    target = transforms.ToTensor()(target)

    return image, mask, target


img_ids = get_test_id(args.test_list_path)

with open("metrics.txt", "w") as f:
    for img_id in tqdm(img_ids):
        image, mask, target = get_test_data(args.img_path, args.mask_path,
                                            args.target_path, img_id)
        with torch.no_grad():
            image = image.to(device).unsqueeze(0)
            mask = mask.to(device).unsqueeze(0)
            target = target.to(device).unsqueeze(0)
            output = model(image, mask)
            output = utils.tensor2im(output, imtype=np.float32)
            target = utils.tensor2im(target, imtype=np.float32)
            mse_score_op = mse(output, target)
            psnr_score_op = psnr(target,
                                 output,
                                 data_range=output.max() - output.min())
            f.write('ID:{}, MSE:{}, PSNR:{}\n'.format(img_id, mse_score_op,
                                                      psnr_score_op))
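The excerpt above leaves its metric imports offscreen. A plausible set, assuming scikit-image (the data_range keyword matches skimage.metrics.peak_signal_noise_ratio):

# Assumed imports for Example #18 (not shown in the excerpt):
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr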