Exemplo n.º 1
0
if __name__ == '__main__':

    # Parse command-line arguments via the project's preprocesser.
    args = Preprocesser().get_arguments()

    # Paired dataset with a single-sample, order-preserving loader.
    loader = DataLoader(ABDataset(args),
                        batch_size=1,
                        shuffle=False,
                        num_workers=2)

    # Restore the trained generator from its checkpoint.
    model = DivCycleGAN(args)
    model.load()

    visualizer = Visualizer(args)

    # HTML page collecting the sampled outputs.
    web_dir = os.path.join("./result", args.taskname, "sample")
    webpage = html.HTML(web_dir,
                        "Experiment = %s, Phase = test" % args.taskname)

    for data in loader:
        samples = OrderedDict()
        # Draw 30 random codes per input to show output diversity.
        for sample_idx in range(30):
            z = random_generator(args.numberclasses)
            model.forward(data["A"], data["B"], data["C"], z)
            visuals = model.get_current_visuals()
            samples["real_A"] = visuals["real_A"]
            samples["fake_B_%d" % sample_idx] = visuals["fake_B2"]

        visualizer.save_images(webpage, samples, data["Name"][0])

    webpage.save()
Exemplo n.º 2
0
# Test-time settings: deterministic order, no augmentation, no rescaling.
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip
opt.loadSize = opt.fineSize  # Do not scale!

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

model = create_model(opt)
# The guidance loss needs the innerCos target preset before inference.
model.preset_innerCos()

visualizer = Visualizer(opt)

# Website that collects the generated images.
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(
    web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
    (opt.name, opt.phase, opt.which_epoch))

# Run inference on at most opt.how_many batches, timing each one.
for batch_idx, data in enumerate(dataset):
    if batch_idx >= opt.how_many:
        break
    start = time.time()
    model.set_input(data)
    model.test()
    elapsed = time.time() - start
    print(elapsed)
    visuals = model.get_current_visuals()
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)
Exemplo n.º 3
0
# Parse test-time options and pin the settings the test loop depends on.
opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads=1
opt.batchSize = 1  # test code only supports batchSize=1
opt.serial_batches = True  # no shuffle
# create dataset
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
model.eval()
print('Loading model %s' % opt.model)

# create website; the '_sync' suffix marks runs that reuse one batch of
# latent codes for every input image (see the opt.sync branches below)
web_dir = os.path.join(opt.results_dir,
                       opt.phase + '_sync' if opt.sync else opt.phase)
webpage = html.HTML(
    web_dir, 'Training = %s, Phase = %s, G = %s, E = %s' %
    (opt.name, opt.phase, opt.G_path, opt.E_path))

# sample random z once up front when outputs should be synchronized
if opt.sync:
    z_samples = get_random_z(opt)

# test stage: islice caps the run at opt.how_many inputs
for i, data in enumerate(islice(dataset, opt.how_many)):
    model.set_input(data)
    print('process input image %3.3d/%3.3d' % (i, opt.how_many))
    if not opt.sync:
        # fresh latent codes per image when not synchronized
        z_samples = get_random_z(opt)
    for nn in range(opt.n_samples + 1):
        # nn == 0 is the encoded reconstruction unless encoding is disabled
        encode_B = nn == 0 and not opt.no_encode
        _, real_A, fake_B, real_B, _ = model.test_simple(
# NOTE(review): fragment — `joint_inference_model` and `joint_opt` are
# created above this chunk; verify their setup against the full script.
joint_inference_model.opt_maskgen.load_image = 1
joint_inference_model.opt_maskgen.min_box_size = 128
joint_inference_model.opt_maskgen.max_box_size = -1 # not actually used

# Short aliases for the mask-generator and image-generator option sets.
opt_maskgen = joint_inference_model.opt_maskgen
opt_pix2pix = joint_inference_model.opt_imggen

# Load data
data_loader = SegmentationDataset()
data_loader.initialize(opt_maskgen)
visualizer = Visualizer(opt_maskgen)
# create website
base_name = joint_opt.base_name #'giraffe_run'
web_dir = os.path.join('./results', base_name, 'val')

webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s' %
                   ('Joint Inference', 'val'))

# Save directory (created only when missing)
if not os.path.exists('./results'):
    os.makedirs('./results')
if not os.path.exists('./results/test_joint_inference'):
    os.makedirs('./results/test_joint_inference')
save_dir = './results/test_joint_inference/'
print(data_loader.dataset_size)
# Process at most joint_opt.how_many samples.
for i in range(data_loader.dataset_size):
    if i >= joint_opt.how_many:
        break
    try:
        # Get data
        raw_inputs, inst_info = data_loader.get_raw_inputs(i)
        img_orig, label_orig = joint_inference_model.normalize_input( \
Exemplo n.º 5
0
#########
# Model, visualizer, and global step counter for training.
model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0

#########
# Start a MATLAB engine for use elsewhere in the script.
import matlab.engine  # matlab 2017b is needed for python 3.6!!!
eng = matlab.engine.start_matlab()

# create website
import os
from util import html
test_opt.results_dir = './results/'
web_dir = os.path.join(test_opt.results_dir, test_opt.name)
webpage = html.HTML(
    web_dir, 'Experiment = %s, Phase = %s' % (test_opt.name, test_opt.phase))

# Epoch loop over opt.niter + opt.niter_decay epochs.
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_parameters()

        # Push current visuals to the visualizer every display_freq steps.
        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(),
                                               epoch)
Exemplo n.º 6
0
        # NOTE(review): fragment — this line closes a call whose opening
        # (presumably the visualizer construction) is above this chunk.
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            opt.epoch_count, opt.niter + opt.niter_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch

        # if epoch != 0 and epoch % 20 == 0:
        # Dump attention visualizations only at these checkpoint epochs.
        if epoch == 20 or epoch == 40 or epoch == 60:
            web_dir = os.path.join('./results/', opt.name,
                                   '%s_%s' % ('attn', str(epoch)))
            webpage = html.HTML(
                web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
                (opt.name, 'attn', str(epoch)))
            for i, data in enumerate(dataset):  # inner loop within one epoch
                iter_start_time = time.time(
                )  # timer for computation per iteration
                if total_iters % opt.print_freq == 0:
                    t_data = iter_start_time - iter_data_time
                visualizer.reset()
                total_iters += opt.batch_size
                epoch_iter += opt.batch_size
                model.set_input(
                    data)  # unpack data from dataset and apply preprocessing
                model.optimize_parameters(
                )  # calculate loss functions, get gradients, update network weights

                visuals = model.get_current_visuals()
# Test-time constraints: one sample at a time, deterministic order.
opt.batch_size = 1  # test code only supports batch_size=1
opt.serial_batches = True  # no shuffle

# create dataset
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

# Build the model and switch it to evaluation mode.
model = create_model(opt)
model.setup(opt)
model.eval()
print('Loading model %s' % opt.model)

# create website
if opt.sync:
    phase_dir = opt.phase + '_sync'
else:
    phase_dir = opt.phase
web_dir = os.path.join(opt.results_dir, phase_dir)
webpage = html.HTML(
    web_dir,
    'Training = %s, Phase = %s, Class =%s' % (opt.name, opt.phase, opt.name))

# sample random z
if opt.sync:
    # One shared batch of latent codes for every test image.
    z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)

# test stage
for image_idx, data in enumerate(islice(dataset, opt.num_test)):
    model.set_input(data)
    print('process input image %3.3d/%3.3d' % (image_idx, opt.num_test))
    if not opt.sync:
        # Fresh latent codes for each image when not synchronized.
        z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
    for sample_idx in range(opt.n_samples + 1):
        # sample_idx == 0 is the encoded reconstruction unless disabled.
        encode = sample_idx == 0 and not opt.no_encode
        real_A, fake_B, real_B = model.test(z_samples[[sample_idx]],
                                            encode=encode)
Exemplo n.º 8
0
def main():
    """pix2pixHD-style test entry point.

    Parses test options, runs the generator (or a TensorRT/ONNX engine)
    over the dataset, and saves label/synthesized image pairs to an HTML
    results page. Optionally exports the model to ONNX and exits.
    """
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    # test: build the model only when not delegating to an external engine
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        # Optionally cast the model to reduced precision.
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # Match the input dtype to the model precision chosen above.
        # NOTE(review): `.uint8()` is not a standard torch.Tensor method
        # (`.byte()` is) — confirm the data_type == 8 path is exercised.
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].uint8()
            data['inst'] = data['inst'].uint8()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith(
                "onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx,
                              verbose=True)
            exit(0)
        minibatch = 1
        # Dispatch to TensorRT, ONNX, or the native PyTorch model.
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch,
                                       [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                                 [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'],
                                        data['image'])

        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
Exemplo n.º 9
0
def test(opt):
    """Run inference and save results for the model described by *opt*.

    When opt.traverse or opt.deploy is set, processes the explicit
    opt.image_path_list and writes videos / row images / per-image outputs;
    otherwise iterates the dataset and records results on an HTML page.
    """
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#test batches = %d' % (int(dataset_size / len(opt.sort_order))))
    visualizer = Visualizer(opt)
    model = create_model(opt)
    model.eval()

    # create webpage; embed the seed in the directory name so seeded runs
    # do not overwrite each other
    if opt.random_seed != -1:
        exp_dir = '%s_%s_seed%s' % (opt.phase, opt.which_epoch,
                                    str(opt.random_seed))
    else:
        exp_dir = '%s_%s' % (opt.phase, opt.which_epoch)
    web_dir = os.path.join(opt.results_dir, opt.name, exp_dir)

    if opt.traverse or opt.deploy:
        if opt.traverse:
            out_dirname = 'traversal'
        else:
            out_dirname = 'deploy'
        output_dir = os.path.join(web_dir, out_dirname)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        for image_path in opt.image_path_list:
            print(image_path)
            data = dataset.dataset.get_item_from_path(image_path)
            visuals = model.inference(data)
            if opt.traverse and opt.make_video:
                out_path = os.path.join(
                    output_dir,
                    os.path.splitext(os.path.basename(image_path))[0] + '.mp4')
                visualizer.make_video(visuals, out_path)
            elif opt.traverse or (opt.deploy and opt.full_progression):
                if opt.traverse and opt.compare_to_trained_outputs:
                    out_path = os.path.join(
                        output_dir,
                        os.path.splitext(os.path.basename(image_path))[0] +
                        '_compare_to_{}_jump_{}.png'.format(
                            opt.compare_to_trained_class,
                            opt.trained_class_jump))
                else:
                    out_path = os.path.join(
                        output_dir,
                        os.path.splitext(os.path.basename(image_path))[0] +
                        '.png')
                visualizer.save_row_image(visuals,
                                          out_path,
                                          traverse=opt.traverse)
            else:
                out_path = os.path.join(output_dir,
                                        os.path.basename(image_path[:-4]))
                visualizer.save_images_deploy(visuals, out_path)
    else:
        webpage = html.HTML(
            web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
            (opt.name, opt.phase, opt.which_epoch))

        # test
        for i, data in enumerate(dataset):
            if i >= opt.how_many:
                break

            visuals = model.inference(data)
            img_path = data['Paths']
            # Remove empty path entries in place (deleting from the end so
            # earlier indices stay valid). Fix: the inner loop variable used
            # to shadow the outer batch counter `i`; renamed to path_idx.
            rem_ind = []
            for path_idx, path in enumerate(img_path):
                if path != '':
                    print('process image... %s' % path)
                else:
                    rem_ind += [path_idx]

            for ind in reversed(rem_ind):
                del img_path[ind]

            visualizer.save_images(webpage, visuals, img_path)

            # Saving inside the loop persists partial results incrementally.
            webpage.save()
Exemplo n.º 10
0
    # NOTE(review): fragment — the enclosing function/setup is above this
    # chunk; `opt`, `opt_val`, `visualizer`, and `dataset` come from it.
    # Configure the validation pass: no training, results under ./results/.
    opt_val.phase = 'val'
    opt_val.isTrain = False  # get validating options
    opt_val.results_dir = './results/'
    opt_val.aspect_ratio = 1.0
    # opt_val.sample_nums = 100
    # opt_val.dataset_mode = "alignedm2md"

    dataset_val = create_dataset(opt_val)
    #     opt_val.print_options(opt_val)
    print(opt_val)
    web_dir_val = os.path.join(
        opt_val.results_dir, opt_val.name,
        '{}_{}'.format(opt_val.phase,
                       opt_val.epoch))  # define the website directory
    webpage_val = html.HTML(
        web_dir_val, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt_val.name, opt_val.phase, opt_val.epoch))

    # Best-so-far validation MAE (initialized to a large sentinel).
    min_mae = 100.
    for epoch in range(
            opt.epoch_count, opt.niter + opt.niter_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        visualizer.reset(
        )  # reset the visualizer: make sure it saves the results to HTML at least once every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
Exemplo n.º 11
0
# Training driver: build the data pipeline and model, then run the epoch
# loop, logging errors through the visualizer.
opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
TrainOptions().printandsave(opt)
print('#training images = %d' % dataset_size)
print('#len of dataset = %d' % len(dataset))

model = create_model(opt)
visualizer = Visualizer(opt)
visualizer.set_x(opt, len(dataset))
total_steps = 0

# Results live next to the checkpoints directory.
opt.results_dir = os.path.join(os.path.dirname(opt.checkpoints_dir), 'results')
web_dir = os.path.join(opt.results_dir, opt.name)
webpage = html.HTML(web_dir, 'Experiment = %s' % (opt.name))

for epoch in tqdm(range(opt.epoch_count, opt.niter + opt.niter_decay + 1)):

    for i, data in enumerate(dataset):
        # Keep the very first batch for later use. Fix: the original
        # `else: save_data = save_data` branch was a redundant no-op
        # self-assignment and has been removed.
        if epoch == opt.epoch_count and i == 0:
            save_data = data

        model.set_input(data)
        model.optimize_parameters()

        # NOTE(review): total_steps is never incremented in this script, so
        # this condition is always true (0 % n == 0) — confirm whether a
        # `total_steps += opt.batchSize` was lost upstream.
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            visualizer.print_current_errors(epoch, i, errors)
Exemplo n.º 12
0
def train_target(param_dir, if_train, loadSize, continue_train = False):
    """Train (or run inference with) a pix2pixHD model for the target video.

    param_dir: path to a pickle file holding a saved options namespace.
    if_train: True -> run the training loop; False -> run inference over
        the dataset and write an HTML results page.
    loadSize: image size, assigned to both opt.loadSize and opt.fineSize.
    continue_train: resume from the latest checkpoint when True.
    """

    # import pix2pixHD module
    pix2pixhd_dir = Path('../src/pix2pixHD/')  # should download github code 'pix2pixHD'

    import sys
    sys.path.append(str(pix2pixhd_dir))

    #get_ipython().run_line_magic('load_ext', 'autoreload')
    #get_ipython().run_line_magic('autoreload', '2')

    from options.train_options import TrainOptions
    from data.data_loader import CreateDataLoader
    from models.models import create_model
    import util.util as util
    from util.visualizer import Visualizer
    from util import html

    # Restore the pickled options and override the size/resume settings.
    with open(param_dir, mode='rb') as f:
        opt = pickle.load(f)
        opt.loadSize = loadSize
        opt.fineSize = loadSize
        #opt.fine_size = loadSize
        #opt.loadsize = loadSize
        opt.continue_train = continue_train
        print('opt parameters: ', opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    if if_train:

        print('#training images = %d' % dataset_size)
        if opt.continue_train:
            print ('Resumming training ...')
        else:
            print ('Starting new training ...')

        # Deltas align the periodic actions (display/print/save) with any
        # resumed step count.
        start_epoch, epoch_iter = 1, 0
        total_steps = (start_epoch-1) * dataset_size + epoch_iter
        display_delta = total_steps % opt.display_freq
        print_delta = total_steps % opt.print_freq
        save_delta = total_steps % opt.save_latest_freq

        for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):  # (1,20+20+1)
            epoch_start_time = time.time()
            if epoch != start_epoch:
                epoch_iter = epoch_iter % dataset_size
            for i, data in enumerate(dataset, start=epoch_iter):
                iter_start_time = time.time()
                total_steps += opt.batchSize
                epoch_iter += opt.batchSize

                # whether to collect output images
                save_fake = total_steps % opt.display_freq == display_delta

                ############## Forward Pass ######################
                losses, generated = model(Variable(data['label']), Variable(data['inst']), 
                    Variable(data['image']), Variable(data['feat']), infer=save_fake)

                # sum per device losses
                losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
                loss_dict = dict(zip(model.module.loss_names, losses))

                # calculate final loss scalar
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)

                ############### Backward Pass ####################
                # update generator weights
                model.module.optimizer_G.zero_grad()
                loss_G.backward()
                model.module.optimizer_G.step()

                # update discriminator weights
                model.module.optimizer_D.zero_grad()
                loss_D.backward()
                model.module.optimizer_D.step()

                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

                ############## Display results and errors ##########
                ### print out errors
                if total_steps % opt.print_freq == print_delta:
                    errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                    t = (time.time() - iter_start_time) / opt.batchSize
                    visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                    visualizer.plot_current_errors(errors, total_steps)

                ### display output images
                if save_fake:
                    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                           ('synthesized_image', util.tensor2im(generated.data[0])),
                                           ('real_image', util.tensor2im(data['image'][0]))])
                    visualizer.display_current_results(visuals, epoch, total_steps)

                ### save latest model
                if total_steps % opt.save_latest_freq == save_delta:
                    print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                    model.module.save('latest')
                    np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

                if epoch_iter >= dataset_size:
                    break

            # end of epoch 
            iter_end_time = time.time()
            print('End of epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

            ### save model for this epoch
            if epoch % opt.save_epoch_freq == 0:  # opt.save_epoch_freq == 10
                print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
                model.module.save('latest')
                model.module.save(epoch)
                np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')

            ### instead of only training the local enhancer, train the entire network after certain iterations
            if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
                model.module.update_fixed_params()

            ### linearly decay learning rate after certain iterations
            if epoch > opt.niter:
                model.module.update_learning_rate()

        torch.cuda.empty_cache()
    else:

        # create website
        web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
        webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

        # Inference-only path: synthesize each sample and save it to the page.
        for data in tqdm(dataset):
            minibatch = 1 
            generated = model.inference(data['label'], data['inst'])

            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0]))])
            img_path = data['path']
            visualizer.save_images(webpage, visuals, img_path)
        webpage.save()
        torch.cuda.empty_cache()
def prepare_face_enhancer_data(target_dir, run_name):
    """Build the directory tree and data the face-enhancer stage consumes.

    - Copies training frames/labels into test_real / test_img / test_label.
    - Runs the pix2pixHD generator over the dataset to populate test_sync.
    - Copies previously synthesized transfer frames into
      face_transfer/test_sync.

    Relies on module-level globals `device` and `dir_name`.
    """

    # Output directory layout (all idempotent thanks to exist_ok=True).
    face_sync_dir = os.path.join(target_dir, 'face')
    os.makedirs(face_sync_dir, exist_ok=True)
    test_sync_dir = os.path.join(face_sync_dir, 'test_sync')
    os.makedirs(test_sync_dir, exist_ok=True)
    test_real_dir = os.path.join(face_sync_dir, 'test_real')
    os.makedirs(test_real_dir, exist_ok=True)
    test_img = os.path.join(target_dir, 'test_img')
    os.makedirs(test_img, exist_ok=True)
    test_label = os.path.join(target_dir, 'test_label')
    os.makedirs(test_label, exist_ok=True)

    transfer_face_sync_dir = os.path.join(target_dir, 'face_transfer')
    os.makedirs(transfer_face_sync_dir, exist_ok=True)
    transfer_test_sync_dir = os.path.join(transfer_face_sync_dir, 'test_sync')
    os.makedirs(transfer_test_sync_dir, exist_ok=True)
    transfer_test_real_dir = os.path.join(transfer_face_sync_dir, 'test_real')
    os.makedirs(transfer_test_real_dir, exist_ok=True)

    train_dir = os.path.join(target_dir, 'train', 'train_img')
    label_dir = os.path.join(target_dir, 'train', 'train_label')

    print('Prepare test_real....')
    # Copy every training frame (and its label) into the test directories.
    for img_file in tqdm(sorted(os.listdir(train_dir))):
        img_idx = int(img_file.split('.')[0])
        img = cv2.imread(os.path.join(train_dir, '{:05}.png'.format(img_idx)))
        label = cv2.imread(os.path.join(label_dir,
                                        '{:05}.png'.format(img_idx)))
        cv2.imwrite(os.path.join(test_real_dir, '{:05}.png'.format(img_idx)),
                    img)
        cv2.imwrite(
            os.path.join(transfer_test_real_dir, '{:05}.png'.format(img_idx)),
            img)
        cv2.imwrite(os.path.join(test_img, '{:05}.png'.format(img_idx)), img)
        cv2.imwrite(os.path.join(test_label, '{:05}.png'.format(img_idx)),
                    label)

    print('Prepare test_sync....')

    # Configure the pix2pixHD test options module for this run.
    import src.config.test_opt as opt
    if device == torch.device('cpu'):
        opt.gpu_ids = []
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    opt.checkpoints_dir = os.path.join(dir_name, '../../checkpoints/')
    opt.dataroot = target_dir
    opt.name = run_name
    opt.nThreads = 0
    opt.results_dir = os.path.join(dir_name, '../../face_enhancer_results/')

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)

    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    model = create_model(opt)

    # Synthesize each frame and record it on the webpage.
    for data in tqdm(dataset):
        minibatch = 1
        generated = model.inference(data['label'], data['inst'])

        visuals = OrderedDict([('synthesized_image',
                                util.tensor2im(generated.data[0]))])
        img_path = data['path']
        visualizer.save_images(webpage, visuals, img_path)
    webpage.save()
    torch.cuda.empty_cache()

    print(f'Copy the synthesized images in {test_sync_dir}...')
    synthesized_image_dir = os.path.join(dir_name,
                                         '../../face_enhancer_results',
                                         run_name, 'test_latest/images/')
    img_list = [
        f for f in os.listdir(synthesized_image_dir)
        if f.endswith('synthesized_image.jpg')
    ]
    # Rename/copy the synthesized outputs into the test_sync layout.
    for img_file in tqdm(sorted(img_list)):
        img_idx = int(img_file.split('_')[0])
        img = cv2.imread(
            os.path.join(synthesized_image_dir,
                         '{:05}_synthesized_image.jpg'.format(img_idx)))
        cv2.imwrite(os.path.join(test_sync_dir, '{:05}.png'.format(img_idx)),
                    img)

    print('Copy transfer_test_sync')
    previous_run_img_dir = os.path.join(dir_name, '../../results', run_name,
                                        'test_latest/images/')
    img_list = [
        f for f in os.listdir(previous_run_img_dir)
        if f.endswith('synthesized_image.jpg')
    ]
    # Same copy step for the earlier transfer run's outputs.
    for img_file in tqdm(sorted(img_list)):
        img_idx = int(img_file.split('_')[0])
        img = cv2.imread(
            os.path.join(previous_run_img_dir,
                         '{:05}_synthesized_image.jpg'.format(img_idx)))
        cv2.imwrite(
            os.path.join(transfer_test_sync_dir, '{:05}.png'.format(img_idx)),
            img)
Exemplo n.º 14
0
    # NOTE(review): fragment — the enclosing def's signature and earlier
    # plotting code are above this chunk; it saves the last heatmap and
    # returns the (images, captions, links) triple for one webpage row.
    plt.savefig(out_path + img_name + suffix + '_heatmap_BArec.png')
    plt.close()
    # Filenames of the ten row images: the A->B cycle outputs and heatmaps,
    # followed by the B->A counterparts.
    ims = [img_name + suffix + '_real_A.png', img_name + suffix + '_fake_B.png', img_name + suffix + '_rec_A.png', 
           img_name + suffix + '_heatmap_AB.png', img_name + suffix + '_heatmap_ABrec.png',
           img_name + suffix + '_real_B.png', img_name + suffix + '_fake_A.png', img_name + suffix + '_rec_B.png', 
           img_name + suffix + '_heatmap_BA.png', img_name + suffix + '_heatmap_BArec.png']
    # Matching captions (suffix only, without the image-name prefix).
    txts = [suffix + '_real_A', suffix + '_fake_B', suffix + '_rec_A', 
            suffix + '_heatmap_AB', suffix + '_heatmap_ABrec',
            suffix + '_real_B', suffix + '_fake_A', suffix + '_rec_B', 
            suffix + '_heatmap_BA', suffix + '_heatmap_BArec']
    # Link targets mirror the displayed images.
    links = [img_name + suffix + '_real_A.png', img_name + suffix + '_fake_B.png', img_name + suffix + '_rec_A.png',
             img_name + suffix + '_heatmap_AB.png', img_name + suffix + '_heatmap_ABrec.png',
             img_name + suffix + '_real_B.png', img_name + suffix + '_fake_A.png', img_name + suffix + '_rec_B.png', 
             img_name + suffix + '_heatmap_BA.png', img_name + suffix + '_heatmap_BArec.png']
    return ims, txts, links

# Build an HTML page juxtaposing ground-truth, CycleGAN, and energy-model
# heatmap rows for every HG/LG test image.
web_dir = out_root
webpage = html.HTML(web_dir, 'analysis_heatmap')
for pre in ['HG', 'LG']:
    for idx in range(1, 25 + 1):
        img_name = '%s_%04d_090' % (pre, idx)
        # Fix: `print img_name` was Python-2-only syntax (a SyntaxError
        # under Python 3); the parenthesized form works on both.
        print(img_name)
        webpage.add_header(img_name + '_groundtruth')
        ims, txts, links = draw_webpage_gt(img_name, out_path)
        webpage.add_images(ims, txts, links, width=256)
        ims, txts, links = draw_webpage_cyclegan(img_name, cyclegan_path, out_path, '_cyclegan')
        webpage.add_images(ims, txts, links, width=256)
        ims, txts, links = draw_webpage_cyclegan(img_name, energy_path, out_path, '_energy')
        webpage.add_images(ims, txts, links, width=256)
webpage.save()
Exemplo n.º 15
0
def main():
    """Run pose-transfer inference over the test set and write an HTML page.

    Uses module-level globals (opt, dataset, model, visualizer, ...) set up
    elsewhere in the script. Fixes the Python-2-only `print image_dir`
    statement at the end, which is a SyntaxError under Python 3.
    """
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break

        # Unpack one paired-pose sample (channel counts noted inline).
        a_image_tensor = data['a_image_tensor']  # 3
        b_image_tensor = data['b_image_tensor']  # 3
        b_label_tensor = data['b_label_tensor']  # 18
        a_parsing_tensor = data['a_parsing_tensor']  # 1
        b_parsing_tensor = data['b_parsing_tensor']  # 1
        b_label_show_tensor = data['b_label_show_tensor']
        theta_aff = data['theta_aff_tensor']  # 2
        theta_tps = data['theta_tps_tensor']  # 2
        theta_aff_tps = data['theta_aff_tps_tensor']  # 2
        a_jpg_path = data['a_jpg_path']
        b_jpg_path = data['b_jpg_path']

        # Concatenate all conditioning inputs along the channel axis.
        input_tensor = torch.cat([a_image_tensor, b_image_tensor, b_label_tensor, a_parsing_tensor, b_parsing_tensor, \
                                  theta_aff, theta_tps, theta_aff_tps], dim=1)
        input_var = Variable(input_tensor.type(torch.cuda.FloatTensor))

        model.eval()
        fake_b = model.inference(input_var)

        # show_image_tensor_1 = torch.cat((a_image_tensor, b_label_show_tensor, b_image_tensor), dim=3)
        # show_image_tensor_2 = torch.cat((a_parsing_rgb_tensor, b_parsing_rgb_tensor, fake_b.data.cpu()), dim=3)
        # show_image_tensor = torch.cat((show_image_tensor_1, show_image_tensor_2), dim=2)
        # test_list = [('a | b | fake_b', tensor2im(show_image_tensor[0]))]

        # Render the parsing maps as RGB for side-by-side display.
        a_parsing_rgb_tensor = parsingim_2_tensor(
            a_parsing_tensor[0],
            opt=opt,
            parsing_label_nc=opt.parsing_label_nc)
        b_parsing_rgb_tensor = parsingim_2_tensor(
            b_parsing_tensor[0],
            opt=opt,
            parsing_label_nc=opt.parsing_label_nc)

        # Assemble the 2x3 comparison grid: width-wise (dim=3) then
        # height-wise (dim=2).
        show_image_tensor_1 = torch.cat(
            (a_image_tensor, b_label_show_tensor, b_image_tensor), dim=3)
        show_image_tensor_2 = torch.cat(
            (a_parsing_rgb_tensor, b_parsing_rgb_tensor,
             fake_b.data[0:1, :, :, :].cpu()),
            dim=3)
        show_image_tensor = torch.cat(
            (show_image_tensor_1[0:1, :, :, :], show_image_tensor_2), dim=2)
        test_list = [('a-b-fake_b', tensor2im(show_image_tensor[0])),
                     ('fake_image', util.tensor2im(fake_b.data[0])),
                     ('b_image', util.tensor2im(b_image_tensor[0]))]

        ### save image
        visuals = OrderedDict(test_list)
        visualizer.save_images(webpage, visuals, a_jpg_path[0], b_jpg_path[0])

        if i % 100 == 0:
            print('[%s]process image... %s' % (i, a_jpg_path[0]))

    webpage.save()

    image_dir = webpage.get_image_dir()
    # Fix: `print image_dir` was Python-2-only syntax.
    print(image_dir)