Пример #1
0
def main():
    """Run parsing-stage inference over ``dataset`` and save the results as a web page.

    For each of the first ``opt.how_many`` samples:
    - the source parsing, target parsing and target label tensors are
      concatenated along dim 1 and fed to ``model.inference``;
    - the inputs and the predicted parsing are converted to images;
    - the images are written into an HTML results page under
      ``opt.results_dir/opt.name/<phase>_<which_epoch>``.

    Relies on module-level globals defined elsewhere in the file:
    ``opt``, ``dataset``, ``model``, ``visualizer``, ``html`` and ``util``.
    """
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break

        a_parsing_tensor = data['a_parsing_tensor']
        b_parsing_tensor = data['b_parsing_tensor']
        b_label_tensor = data['b_label_tensor']
        b_label_show_tensor = data['b_label_show_tensor']
        a_jpg_path = data['a_img_path']
        b_jpg_path = data['b_img_path']

        # Stack the conditioning inputs channel-wise into one network input.
        input_tensor = torch.cat(
            (a_parsing_tensor, b_parsing_tensor, b_label_tensor), dim=1)
        input_var = Variable(input_tensor.type(torch.cuda.FloatTensor))

        fake_b_parsing = model.inference(input_var)
        # Only element 0 of each tensor / path list is visualised.
        test_list = [
            ('b_label_show', util.tensor2im(b_label_show_tensor[0])),
            ('a_parsing_tensor_RGB',
             util.parsing2im(
                 util.label_2_onhot(a_parsing_tensor[0],
                                    parsing_label_nc=opt.parsing_label_nc))),
            ('fake_b_parsing_RGB', util.parsing2im(fake_b_parsing.data[0])),
            ('b_parsing_tensor_RGB',
             util.parsing2im(
                 util.label_2_onhot(b_parsing_tensor[0],
                                    parsing_label_nc=opt.parsing_label_nc))),
            ('fake_b_parsing',
             util.parsing_2_onechannel(fake_b_parsing.data[0])),
        ]

        ### save image
        visuals = OrderedDict(test_list)
        visualizer.save_images_parsing_label(webpage, visuals, a_jpg_path[0],
                                             b_jpg_path[0])

        print('[%s]process image... %s' % (i, a_jpg_path[0]))
        # TODO(review): starting from zero, why are only 12779 images
        # produced when it should be 12800? Could 11 pairs be duplicates?
        # Check the pair file. Strange -- does it need 12800 + 21?

    webpage.save()

    image_dir = webpage.get_image_dir()
    print(image_dir)
Пример #2
0
def get_valList(model, opt,
                val_list_path="./datasets/deepfashion/paper_images/256/val_img_path.txt"):
    """Run ``model`` on a fixed list of validation pairs and collect visuals.

    Each non-blank line of ``val_list_path`` holds at least two
    whitespace-separated fields:
    - the source image path, relative to ``opt.dataroot``;
    - the target image path, relative to ``opt.dataroot``.
    For every pair, ``get_test_result`` is called and four ``(name, image)``
    entries are appended. Each name is suffixed with the pair's line index.

    Args:
        model: passed through unchanged to ``get_test_result``.
        opt: options object; ``opt.dataroot`` is joined with each listed path
            and ``opt`` is passed through to ``get_test_result``.
        val_list_path: text file listing the validation pairs. Defaults to the
            path this function has always used.

    Returns:
        A list of ``(name, image)`` tuples, four per validation pair, in file
        order.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
    """
    val_list = []
    # Read the whole list up front so the file handle is closed before the
    # (potentially slow) model inference runs.
    with open(val_list_path) as list_file:
        lines = list_file.readlines()

    for i, line in enumerate(lines):
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. a trailing empty line) rather than
            # failing with IndexError.
            continue
        image_a_path, image_b_path = fields[0], fields[1]
        a_jpg_path = os.path.join(opt.dataroot, image_a_path)
        b_jpg_path = os.path.join(opt.dataroot, image_b_path)
        a_parsing_tensor, b_parsing_tensor, fake_b_parsing, b_label_show_tensor = get_test_result(
            a_jpg_path, b_jpg_path, model, opt)

        # NOTE(review): unlike the other call sites, label_2_onhot is called
        # here without parsing_label_nc -- confirm its default matches
        # opt.parsing_label_nc.
        val_list.append(
            ('val_b_label_show_{}'.format(i), tensor2im(b_label_show_tensor)))
        val_list.append(('val_a_parsing_{}'.format(i),
                         parsing2im(label_2_onhot(a_parsing_tensor))))
        val_list.append(('val_b_parsing_{}'.format(i),
                         parsing2im(label_2_onhot(b_parsing_tensor))))
        val_list.append(('val_fake_b_parsing_{}'.format(i),
                         parsing2im(fake_b_parsing.data[0])))

    return val_list
Пример #3
0
def generate_fake_B(a_image_tensor, b_image_tensor, b_label_tensor,
                    a_parsing_tensor):
    """Generate the target image ``fake_b`` with the two-stage pipeline.

    Stage I: ``model_1.inference_2`` predicts the target parsing from the
    source parsing and target label, concatenated along dim 1.

    GEO: the source parsing and the predicted parsing are written as PNG files
    into the current working directory. ``generate_theta`` reads those files to
    estimate transform parameters, which ``get_thetas_affgrid_tensor_by_json``
    turns into three tensors (affine, TPS, affine+TPS). Each tensor gets a
    leading dim added in place.

    Stage II: ``model_2.inference`` receives the following, concatenated along
    dim 1:
    - the source and target images;
    - the target label;
    - the source parsing and the predicted parsing label tensor (re-read from
      disk);
    - the three theta tensors.

    Uses module-level globals defined elsewhere: ``model_1``, ``model_2``,
    ``opt``, ``geo``, ``affTnf``, ``tpsTnf``, ``util`` and ``cv``.

    Returns:
        A tuple ``(fake_b, fake_b_parsing_label_tensor)``: the stage II output
        and the predicted parsing label tensor (with the added leading dim).
    """
    # ---------------- stage I: predict target parsing ----------------
    stage1_var = Variable(
        torch.cat([a_parsing_tensor, b_label_tensor], dim=1).type(
            torch.cuda.FloatTensor))
    model_1.eval()
    predicted_parsing = model_1.inference_2(stage1_var)

    source_rgb = util.parsing2im(
        util.label_2_onhot(a_parsing_tensor[0],
                           parsing_label_nc=opt.parsing_label_nc))
    predicted_rgb = util.parsing2im(predicted_parsing.data[0])
    predicted_label = util.parsing_2_onechannel(predicted_parsing.data[0])

    # The geometry estimator and the label loader below take file paths, so
    # the intermediates are round-tripped through fixed files in the cwd.
    source_rgb_file = './a_parsing_RGB.png'
    predicted_rgb_file = './fake_b_parsing_RGB.png'
    predicted_label_file = './fake_b_parsing_label.png'
    for picture, target_file in ((source_rgb, source_rgb_file),
                                 (predicted_rgb, predicted_rgb_file)):
        util.save_image(picture, target_file)
    cv.imwrite(predicted_label_file, predicted_label)

    # ---------------- GEO: estimate warping parameters ----------------
    theta_json = generate_theta(source_rgb_file, predicted_rgb_file, geo)
    aff, tps, aff_tps = get_thetas_affgrid_tensor_by_json(
        affTnf, tpsTnf, theta_json)
    aff, tps, aff_tps = [theta.unsqueeze_(0) for theta in (aff, tps, aff_tps)]

    # ---------------- stage II: synthesise the target image ----------------
    predicted_label_tensor = get_parsing_label_tensor(
        predicted_label_file, opt).unsqueeze_(0)

    stage2_inputs = [
        a_image_tensor, b_image_tensor, b_label_tensor, a_parsing_tensor,
        predicted_label_tensor, aff, tps, aff_tps,
    ]
    stage2_var = Variable(
        torch.cat(stage2_inputs, dim=1).type(torch.cuda.FloatTensor))
    model_2.eval()
    generated = model_2.inference(stage2_var)

    return generated, predicted_label_tensor
Пример #4
0
            errors = {
                k: v.item() if not isinstance(v, int) else v
                for k, v in loss_dict.items()
            }
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ############## Display output images and Val images ######################
        if save_fake:
            val_list = get_valList(model, opt)
            train_list = [
                ('b_label', tensor2im(b_label_show_tensor[0])),
                ('a_parsing',
                 parsing2im(
                     label_2_onhot(a_parsing_tensor[0],
                                   parsing_label_nc=opt.parsing_label_nc))),
                ('b_parsing',
                 parsing2im(
                     label_2_onhot(b_parsing_tensor[0],
                                   parsing_label_nc=opt.parsing_label_nc))),
                ('fake_b_parsing', parsing2im(fake_b_parsing.data[0]))
            ]
            val_list[0:0] = train_list
            # val_list = train_list
            visuals = OrderedDict(val_list)

            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
Пример #5
0
def main():
    """Run the full stage I + GEO + stage II pipeline over ``dataset`` and save an HTML report.

    For each of the first ``opt.how_many`` samples:

    * stage I: ``model_1.inference_2`` predicts the target parsing from the
      source parsing and the target label, concatenated along dim 1;
    * GEO: the source and predicted parsings are written to PNG files in the
      current working directory, and ``generate_theta`` estimates transform
      parameters from them. These become affine, TPS and affine+TPS theta
      tensors;
    * stage II: ``model_2.inference`` generates the target image from the
      images, labels, parsings and theta tensors, concatenated along dim 1.

    Three visuals are saved per sample into the web page under
    ``opt.results_dir/opt.name/<phase>_<which_epoch>_I_and_II``:

    * a composite of two rows. The top row holds the source image, the target
      label and the target image. The bottom row holds the source parsing, the
      predicted parsing and the generated image;
    * the generated image;
    * the real target image.

    Relies on module-level globals defined elsewhere in the file (``opt``,
    ``dataset``, ``model_1``, ``model_2``, ``geo``, ``affTnf``, ``tpsTnf``,
    ``visualizer``, ``html``, ``util``, ``cv`` and the helper functions).
    """
    web_dir = os.path.join(
        opt.results_dir, opt.name,
        '%s_%s_%s' % (opt.phase, opt.which_epoch, "I_and_II"))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break

        # The trailing numbers are presumably channel counts -- TODO confirm.
        a_image_tensor = data['a_image_tensor']  # 3
        b_image_tensor = data['b_image_tensor']  # 3
        b_label_tensor = data['b_label_tensor']  # 18
        a_parsing_tensor = data['a_parsing_tensor']  # 1
        b_label_show_tensor = data['b_label_show_tensor']
        a_jpg_path = data['a_jpg_path']
        b_jpg_path = data['b_jpg_path']

        ##### stage I #################
        input_1_tensor = torch.cat([a_parsing_tensor, b_label_tensor], dim=1)
        input_1_var = Variable(input_1_tensor.type(torch.cuda.FloatTensor))
        model_1.eval()
        fake_b_parsing = model_1.inference_2(input_1_var)

        a_parsing_tensor_RGB_numpy = util.parsing2im(
            util.label_2_onhot(a_parsing_tensor[0],
                               parsing_label_nc=opt.parsing_label_nc))
        fake_b_parsing_RGB_numpy = util.parsing2im(fake_b_parsing.data[0])
        fake_b_parsing_label = util.parsing_2_onechannel(
            fake_b_parsing.data[0])

        # The geometry estimator and label loader take file paths, so the
        # intermediates are written to disk.
        # NOTE(review): these fixed paths in the cwd are overwritten on every
        # iteration and would collide if two runs share a working directory.
        a_parsing_RGB_path = './a_parsing_RGB.png'
        fake_b_parsing_RGB_path = './fake_b_parsing_RGB.png'
        fake_b_parsing_label_path = './fake_b_parsing_label.png'
        util.save_image(a_parsing_tensor_RGB_numpy, a_parsing_RGB_path)
        util.save_image(fake_b_parsing_RGB_numpy, fake_b_parsing_RGB_path)
        cv.imwrite(fake_b_parsing_label_path, fake_b_parsing_label)

        ##### GEO ######################

        theta_json = generate_theta(a_parsing_RGB_path,
                                    fake_b_parsing_RGB_path, geo)
        theta_aff_tensor, theta_tps_tensor, theta_aff_tps_tensor = get_thetas_affgrid_tensor_by_json(
            affTnf, tpsTnf, theta_json)
        theta_aff_tensor = theta_aff_tensor.unsqueeze_(0)
        theta_tps_tensor = theta_tps_tensor.unsqueeze_(0)
        theta_aff_tps_tensor = theta_aff_tps_tensor.unsqueeze_(0)

        #### stage II #################
        fake_b_parsing_label_tensor = get_parsing_label_tensor(
            fake_b_parsing_label_path, opt)
        fake_b_parsing_label_tensor = fake_b_parsing_label_tensor.unsqueeze_(0)

        input_2_tensor = torch.cat([a_image_tensor, b_image_tensor, b_label_tensor, a_parsing_tensor, fake_b_parsing_label_tensor, \
                                    theta_aff_tensor, theta_tps_tensor, theta_aff_tps_tensor], dim=1)
        input_2_var = Variable(input_2_tensor.type(torch.cuda.FloatTensor))
        model_2.eval()
        fake_b = model_2.inference(input_2_var)

        a_parsing_rgb_tensor = parsingim_2_tensor(
            a_parsing_tensor[0],
            opt=opt,
            parsing_label_nc=opt.parsing_label_nc)
        b_parsing_rgb_tensor = parsingim_2_tensor(
            fake_b_parsing_label_tensor[0],
            opt=opt,
            parsing_label_nc=opt.parsing_label_nc)

        # Build the composite: each row is joined along dim 3, then the two
        # rows are stacked along dim 2 (presumably width / height of NCHW
        # tensors -- confirm).
        show_image_tensor_1 = torch.cat(
            (a_image_tensor, b_label_show_tensor, b_image_tensor), dim=3)
        show_image_tensor_2 = torch.cat(
            (a_parsing_rgb_tensor, b_parsing_rgb_tensor,
             fake_b.data[0:1, :, :, :].cpu()),
            dim=3)
        show_image_tensor = torch.cat(
            (show_image_tensor_1[0:1, :, :, :], show_image_tensor_2), dim=2)
        test_list = [('a-b-fake_b', tensor2im(show_image_tensor[0])),
                     ('fake_image', util.tensor2im(fake_b.data[0])),
                     ('b_image', util.tensor2im(b_image_tensor[0]))]

        ### save image
        visuals = OrderedDict(test_list)
        visualizer.save_images(webpage, visuals, a_jpg_path[0], b_jpg_path[0])

        print('[%s]process image... %s' % (i, a_jpg_path[0]))

    webpage.save()

    image_dir = webpage.get_image_dir()
    print(image_dir)
Пример #6
0
            model.module.optimizer_D.step()

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == 0:
            errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ############## Display output images and Val images ######################
        if 0:
            # valifation for adding skeleton not implemented
            val_list = get_valList(model, opt)
            train_list = [('b_label', tensor2im(b_label_show_tensor[0])),
                           ('a_parsing', parsing2im(label_2_onhot(a_parsing_tensor[0], parsing_label_nc=opt.parsing_label_nc))),
                           ('b_parsing', parsing2im(label_2_onhot(b_parsing_tensor[0], parsing_label_nc=opt.parsing_label_nc))),
                           ('fake_b_parsing', parsing2im(fake_b_parsing.data[0]))]
            val_list[0:0] = train_list
            visuals = OrderedDict(val_list)

            visualizer.display_current_results(visuals, epoch, total_steps)


        ### save latest model
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
       
    # end of epoch