Example #1
def setup_data_loader(args, opts):
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    images_path = args.images_dir if args.images_dir is not None else dataset_args[
        'test_source_root']
    print(f"images path: {images_path}")
    align_function = None
    if args.align:
        align_function = run_alignment
    test_dataset = InferenceDataset(
        root=images_path,
        transform=transforms_dict['transform_test'],
        preprocess=align_function,
        split=args.split,
        opts=opts)

    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, data_loader
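For context, a minimal sketch of how setup_data_loader might be driven. The flag names simply mirror the attributes the function reads (images_dir, align, split, batch, n_sample) and are assumptions, as is loading `opts` from a training checkpoint the way the later examples do:

import argparse
from argparse import Namespace

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_path', required=True)
parser.add_argument('--images_dir', default=None)
parser.add_argument('--align', action='store_true')
parser.add_argument('--split', default=None)
parser.add_argument('--batch', type=int, default=2)
parser.add_argument('--n_sample', type=int, default=None)
args = parser.parse_args()

# `opts` restored from a checkpoint, as in the inference scripts below
ckpt = torch.load(args.checkpoint_path, map_location='cpu')
opts = Namespace(**ckpt['opts'])
args, data_loader = setup_data_loader(args, opts)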
Example #2
def evaluate(options):
    config = InferenceConfig(options)
    config.FITTING_TYPE = options.numAnchorPlanes

    if options.dataset == '':
        dataset = PlaneDataset(options,
                               config,
                               split='test',
                               random=False,
                               load_semantics=False)
    elif options.dataset == 'occlusion':
        config_dataset = copy.deepcopy(config)
        config_dataset.OCCLUSION = False
        dataset = PlaneDataset(options,
                               config_dataset,
                               split='test',
                               random=False,
                               load_semantics=True)
    elif 'nyu' in options.dataset:
        dataset = NYUDataset(options, config, split='val', random=False)
    elif options.dataset == 'synthia':
        dataset = SynthiaDataset(options, config, split='val', random=False)
    elif options.dataset == 'kitti':
        # camera intrinsics packed as [fx, fy, cx, cy, image width, image height]
        camera = np.array([
            9.842439e+02, 9.808141e+02, 6.900000e+02, 2.331966e+02, 1242, 375
        ])
        dataset = InferenceDataset(
            options,
            config,
            image_list=glob.glob('../../Data/KITTI/scene_3/*.png'),
            camera=camera)
    elif options.dataset == '7scene':
        camera = np.array([519, 519, 320, 240, 640, 480], dtype=np.float64)
        dataset = InferenceDataset(
            options,
            config,
            image_list=glob.glob('../../Data/SevenScene/scene_3/*.png'),
            camera=camera)
    elif options.dataset == 'tanktemple':
        # intrinsics normalized by image size (width = height = 1)
        camera = np.array([0.7, 0.7, 0.5, 0.5, 1, 1])
        dataset = InferenceDataset(
            options,
            config,
            image_list=glob.glob('../../Data/TankAndTemple/scene_4/*.jpg'),
            camera=camera)
    elif options.dataset == 'make3d':
        camera = np.array([0.7, 0.7, 0.5, 0.5, 1, 1])
        dataset = InferenceDataset(
            options,
            config,
            image_list=glob.glob('../../Data/Make3D/*.jpg'),
            camera=camera)
    elif options.dataset == 'popup':
        camera = np.array([0.7, 0.7, 0.5, 0.5, 1, 1])
        dataset = InferenceDataset(
            options,
            config,
            image_list=glob.glob('../../Data/PhotoPopup/*.jpg'),
            camera=camera)
    elif options.dataset in ('cross', 'cross_2'):
        image_list = [
            'test/cross_dataset/' + str(c) + '_image.png' for c in range(12)
        ]
        # one 6-vector of intrinsics [fx, fy, cx, cy, width, height] per image
        camera = np.array([587, 587, 320, 240, 640, 480], dtype=np.float64)
        camera_kitti = np.array([
            9.842439e+02, 9.808141e+02, 6.900000e+02, 2.331966e+02, 1242.0,
            375.0
        ])
        camera_synthia = np.array(
            [133.185088, 134.587036, 160.000000, 96.000000, 320, 192])
        camera_tanktemple = np.array([0.7, 0.7, 0.5, 0.5, 1, 1])
        cameras = ([camera] * 4 + [camera_kitti] * 2 + [camera_synthia] * 2 +
                   [camera_tanktemple] * 2 + [camera] * 2)
        dataset = InferenceDataset(options,
                                   config,
                                   image_list=image_list,
                                   camera=cameras)
    elif options.dataset == 'selected':
        image_list = glob.glob('test/selected_images/*_image_0.png')
        image_list = [
            filename for filename in image_list
            if '63_image' not in filename and '77_image' not in filename
        ] + [
            filename for filename in image_list
            if '63_image' in filename or '77_image' in filename
        ]
        camera = np.array([587, 587, 320, 240, 640, 480], dtype=np.float64)
        dataset = InferenceDataset(options,
                                   config,
                                   image_list=image_list,
                                   camera=camera)
    elif options.dataset == 'comparison':
        image_list = [
            'test/comparison/' + str(index) + '_image_0.png'
            for index in [65, 11, 24]
        ]
        camera = np.array([587, 587, 320, 240, 640, 480], dtype=np.float64)
        dataset = InferenceDataset(options,
                                   config,
                                   image_list=image_list,
                                   camera=camera)
    elif 'inference' in options.dataset:
        image_list = glob.glob(options.customDataFolder +
                               '/*.png') + glob.glob(options.customDataFolder +
                                                     '/*.jpg')
        if os.path.exists(options.customDataFolder + '/camera.txt'):
            # a shared camera.txt holds six whitespace-separated values:
            # fx fy cx cy width height
            with open(options.customDataFolder + '/camera.txt', 'r') as f:
                values = [float(token) for token in f.readline().split()]
            camera = np.array(values[:6])
        else:
            # otherwise expect one intrinsics .txt file per image
            camera = [
                filename.replace('.png', '.txt').replace('.jpg', '.txt')
                for filename in image_list
            ]
        dataset = InferenceDataset(options,
                                   config,
                                   image_list=image_list,
                                   camera=camera)

    print('number of images:', len(dataset))

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    epoch_losses = []
    data_iterator = tqdm(dataloader, total=len(dataset))

    specified_suffix = options.suffix
    with torch.no_grad():
        detectors = []
        for method in options.methods:
            if method == 'w':
                options.suffix = 'pair_' + specified_suffix if specified_suffix != '' else 'pair'
                detectors.append(('warping',
                                  PlaneRCNNDetector(options,
                                                    config,
                                                    modelType='pair')))
            elif method == 'b':
                options.suffix = specified_suffix if specified_suffix != '' else ''
                detectors.append(('basic',
                                  PlaneRCNNDetector(options,
                                                    config,
                                                    modelType='pair')))
            elif method == 'o':
                options.suffix = 'occlusion_' + specified_suffix if specified_suffix != '' else 'occlusion'
                detectors.append(('occlusion',
                                  PlaneRCNNDetector(options,
                                                    config,
                                                    modelType='occlusion')))
            elif method == 'p':
                detectors.append(
                    ('planenet', PlaneNetDetector(options, config)))
            elif method == 'e':
                detectors.append(
                    ('planerecover', PlaneRecoverDetector(options, config)))
            elif method == 't':
                if 'gt' in options.suffix:
                    detectors.append(
                        ('manhattan_gt',
                         TraditionalDetector(options, config, 'manhattan_gt')))
                else:
                    detectors.append(
                        ('manhattan_pred',
                         TraditionalDetector(options, config,
                                             'manhattan_pred')))
            elif method == 'n':
                options.suffix = specified_suffix if specified_suffix != '' else ''
                detectors.append(('non_planar',
                                  DepthDetector(options,
                                                config,
                                                modelType='np')))
            elif method == 'r':
                options.suffix = specified_suffix if specified_suffix != '' else ''
                detectors.append(('refine',
                                  PlaneRCNNDetector(options,
                                                    config,
                                                    modelType='refine')))
            elif method == 's':
                options.suffix = specified_suffix if specified_suffix != '' else ''
                detectors.append(
                    ('refine_single',
                     PlaneRCNNDetector(options,
                                       config,
                                       modelType='refine_single')))
            elif method == 'f':
                options.suffix = specified_suffix if specified_suffix != '' else ''
                detectors.append(('final',
                                  PlaneRCNNDetector(options,
                                                    config,
                                                    modelType='final')))

    if not options.debug:
        for method_name in [detector[0] for detector in detectors]:
            os.system('rm ' + options.test_dir + '/*_' + method_name + '.png')

    all_statistics = []
    for name, detector in detectors:
        statistics = [[], [], [], []]
        for sampleIndex, sample in enumerate(data_iterator):
            if options.testingIndex >= 0 and sampleIndex != options.testingIndex:
                if sampleIndex > options.testingIndex:
                    break
                continue
            input_pair = []
            camera = sample[30][0].cuda()
            for indexOffset in [0]:
                images = sample[indexOffset + 0].cuda()
                image_metas = sample[indexOffset + 1].numpy()
                rpn_match = sample[indexOffset + 2].cuda()
                rpn_bbox = sample[indexOffset + 3].cuda()
                gt_class_ids = sample[indexOffset + 4].cuda()
                gt_boxes = sample[indexOffset + 5].cuda()
                gt_masks = sample[indexOffset + 6].cuda()
                gt_parameters = sample[indexOffset + 7].cuda()
                gt_depth = sample[indexOffset + 8].cuda()
                extrinsics = sample[indexOffset + 9].cuda()
                planes = sample[indexOffset + 10].cuda()
                gt_segmentation = sample[indexOffset + 11].cuda()

                # one binary mask per segment id in the ground-truth segmentation
                masks = (gt_segmentation == torch.arange(
                    gt_segmentation.max() + 1).cuda().view(-1, 1, 1)).float()
                input_pair.append({
                    'image': images,
                    'depth': gt_depth,
                    'bbox': gt_boxes,
                    'extrinsics': extrinsics,
                    'segmentation': gt_segmentation,
                    'camera': camera,
                    'plane': planes[0],
                    'masks': masks,
                    'mask': gt_masks
                })

            if sampleIndex >= options.numTestingImages:
                break

            with torch.no_grad():
                detection_pair = detector.detect(sample)

            if options.dataset == 'rob':
                depth = detection_pair[0]['depth'].squeeze().detach().cpu().numpy()
                os.system('rm ' +
                          image_list[sampleIndex].replace('color', 'depth'))
                depth_rounded = np.round(depth * 256)
                depth_rounded[np.logical_or(depth_rounded < 0,
                                            depth_rounded >= 256 * 256)] = 0
                cv2.imwrite(
                    image_list[sampleIndex].replace('color', 'depth').replace(
                        'jpg', 'png'), depth_rounded.astype(np.uint16))
                continue

            if 'inference' not in options.dataset:
                for c in range(len(input_pair)):
                    evaluateBatchDetection(
                        options,
                        config,
                        input_pair[c],
                        detection_pair[c],
                        statistics=statistics,
                        printInfo=options.debug,
                        evaluate_plane=options.dataset == '')
            else:
                for c in range(len(detection_pair)):
                    np.save(
                        options.test_dir + '/' + str(sampleIndex % 500) +
                        '_plane_parameters_' + str(c) + '.npy',
                        detection_pair[c]['detection'][:, 6:9])
                    np.save(
                        options.test_dir + '/' + str(sampleIndex % 500) +
                        '_plane_masks_' + str(c) + '.npy',
                        detection_pair[c]['masks'][:, 80:560])

            if sampleIndex < 30 or options.debug or options.dataset != '':
                visualizeBatchPair(options,
                                   config,
                                   input_pair,
                                   detection_pair,
                                   indexOffset=sampleIndex % 500,
                                   suffix='_' + name + options.modelType,
                                   write_ply=options.testingIndex >= 0,
                                   write_new_view=options.testingIndex >= 0
                                   and 'occlusion' in options.suffix)
            if sampleIndex >= options.numTestingImages:
                break
        if 'inference' not in options.dataset:
            options.keyname = name
            printStatisticsDetection(options, statistics)
            all_statistics.append(statistics)
    if 'inference' not in options.dataset:
        if options.debug and len(detectors) > 1:
            all_statistics = np.concatenate(
                [np.arange(len(all_statistics[0][0])).reshape((-1, 1))] +
                [np.array(statistics[3]) for statistics in all_statistics],
                axis=-1)
            print(all_statistics.astype(np.int32))
        if options.testingIndex == -1:
            np.save('logs/all_statistics.npy', all_statistics)
    return
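Every dataset branch above packs the camera intrinsics into a flat 6-vector; matching the assignments against the known image sizes (KITTI 1242x375, 7-Scenes 640x480) gives the layout [fx, fy, cx, cy, width, height]. A small hypothetical helper makes that convention explicit:

import numpy as np

def make_camera(fx, fy, cx, cy, width, height):
    """Hypothetical helper: pack intrinsics into the 6-vector layout used
    throughout evaluate(): [fx, fy, cx, cy, image width, image height]."""
    return np.array([fx, fy, cx, cy, width, height], dtype=np.float64)

# e.g. the KITTI branch above would reduce to:
# camera = make_camera(9.842439e+02, 9.808141e+02, 6.900000e+02, 2.331966e+02, 1242, 375)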
Example #3
def run():
    test_opts = TestOptions().parse()

    if test_opts.resize_factors is not None:
        factors = test_opts.resize_factors.split(',')
        assert len(factors) == 1, \
            "When running inference, please provide a single downsampling factor!"
        mixed_path_results = os.path.join(
            test_opts.exp_dir, 'style_mixing',
            'downsampling_{}'.format(test_opts.resize_factors))
    else:
        mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing')
    os.makedirs(mixed_path_results, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=True)

    latent_mask = [int(l) for l in opts.latent_mask.split(",")]
    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_batch = input_batch.cuda()
            for image_idx, input_image in enumerate(input_batch):
                # generate random vectors to inject into input image
                vecs_to_inject = np.random.randn(opts.n_outputs_to_generate,
                                                 512).astype('float32')
                multi_modal_outputs = []
                for vec_to_inject in vecs_to_inject:
                    cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to(
                        "cuda")
                    # get latent vector to inject into our input image
                    _, latent_to_inject = net(cur_vec,
                                              input_code=True,
                                              return_latents=True)
                    # get output image with injected style vector
                    res = net(input_image.unsqueeze(0).to("cuda").float(),
                              latent_mask=latent_mask,
                              inject_latent=latent_to_inject,
                              alpha=opts.mix_alpha)
                    multi_modal_outputs.append(res[0])

                # visualize multi modal outputs
                input_im_path = dataset.paths[global_i]
                image = input_batch[image_idx]
                input_image = log_input_image(image, opts)
                res = np.array(input_image.resize((256, 256)))
                for output in multi_modal_outputs:
                    output = tensor2im(output)
                    res = np.concatenate(
                        [res, np.array(output.resize((256, 256)))], axis=1)
                Image.fromarray(res).save(
                    os.path.join(mixed_path_results,
                                 os.path.basename(input_im_path)))
                global_i += 1
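The call net(..., latent_mask=latent_mask, inject_latent=latent_to_inject, alpha=opts.mix_alpha) performs per-layer style mixing in W+ space. A minimal sketch of the idea, an illustration rather than the actual pSp forward: layers listed in latent_mask take the injected code, optionally alpha-blended with the encoder's own code.

import torch

def mix_latents(latent, inject_latent, latent_mask, alpha=None):
    # latent, inject_latent: (batch, n_layers, 512) W+ codes
    mixed = latent.clone()
    for i in latent_mask:
        if alpha is not None:
            # blend the injected style with the encoder's style
            mixed[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * latent[:, i]
        else:
            mixed[:, i] = inject_latent[:, i]
    return mixed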
Example #4
def run():
	test_opts = TestOptions().parse()

	out_path_results = os.path.join(test_opts.exp_dir, 'inference_side_by_side')
	os.makedirs(out_path_results, exist_ok=True)

	# update test options with options used during training
	ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
	opts = ckpt['opts']
	opts.update(vars(test_opts))
	opts = Namespace(**opts)

	net = pSp(opts)
	net.eval()
	net.cuda()

	age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]

	print(f'Loading dataset for {opts.dataset_type}')
	dataset_args = data_configs.DATASETS[opts.dataset_type]
	transforms_dict = dataset_args['transforms'](opts).get_transforms()
	dataset = InferenceDataset(root=opts.data_path,
							   transform=transforms_dict['transform_inference'],
							   opts=opts,
							   return_path=True)
	dataloader = DataLoader(dataset,
							batch_size=opts.test_batch_size,
							shuffle=False,
							num_workers=int(opts.test_workers),
							drop_last=False)

	if opts.n_images is None:
		opts.n_images = len(dataset)

	global_time = []
	global_i = 0
	for input_batch, image_paths in tqdm(dataloader):
		if global_i >= opts.n_images:
			break
		batch_results = {}
		for idx, age_transformer in enumerate(age_transformers):
			with torch.no_grad():
				input_age_batch = [age_transformer(img.cpu()).to('cuda') for img in input_batch]
				input_age_batch = torch.stack(input_age_batch)
				input_cuda = input_age_batch.cuda().float()
				tic = time.time()
				result_batch = run_on_batch(input_cuda, net, opts)
				toc = time.time()
				global_time.append(toc - tic)

				resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
				for i in range(len(input_batch)):
					result = tensor2im(result_batch[i])
					im_path = image_paths[i]
					input_im = log_image(input_batch[i], opts)
					if im_path not in batch_results:
						batch_results[im_path] = np.array(input_im.resize(resize_amount))
					batch_results[im_path] = np.concatenate([batch_results[im_path],
															 np.array(result.resize(resize_amount))],
															axis=1)

		for im_path, res in batch_results.items():
			image_name = os.path.basename(im_path)
			im_save_path = os.path.join(out_path_results, image_name)
			Image.fromarray(np.array(res)).save(im_save_path)
			global_i += 1
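AgeTransformer conditions the encoder on a target age before the forward pass. A rough sketch under the assumption (following the SAM setup) that the normalized target age is appended to the image as a constant fourth channel:

import torch

class AgeChannelTransform:
    """Sketch of age conditioning: append a constant channel holding the
    normalized target age (assumed behavior of AgeTransformer)."""

    def __init__(self, target_age):
        self.target_age = int(target_age)

    def __call__(self, img):
        # img: (3, H, W) tensor; returns a (4, H, W) tensor
        age_channel = torch.full_like(img[:1], self.target_age / 100.0)
        return torch.cat([img, age_channel], dim=0)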
Example #5
def run():
    test_opts = TestOptions().parse()

    if test_opts.resize_factors is not None:
        assert len(test_opts.resize_factors.split(',')) == 1, \
            "When running inference, provide a single downsampling factor!"
        out_path_results = os.path.join(
            test_opts.exp_dir, 'inference_results',
            'downsampling_{}'.format(test_opts.resize_factors))
        out_path_coupled = os.path.join(
            test_opts.exp_dir, 'inference_coupled',
            'downsampling_{}'.format(test_opts.resize_factors))
    else:
        out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
        out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')

    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=True)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = run_on_batch(input_cuda, net, opts)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(opts.test_batch_size):
            result = tensor2im(result_batch[i])
            im_path = dataset.paths[global_i]

            if opts.couple_outputs or global_i % 100 == 0:
                input_im = log_input_image(input_batch[i], opts)
                resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
                if opts.resize_factors is not None:
                    # for super resolution, save the original, down-sampled, and output
                    source = Image.open(im_path)
                    res = np.concatenate([
                        np.array(source.resize(resize_amount)),
                        np.array(input_im.resize(resize_amount,
                                                 resample=Image.NEAREST)),
                        np.array(result.resize(resize_amount))
                    ], axis=1)
                else:
                    # otherwise, save the original and output
                    res = np.concatenate([
                        np.array(input_im.resize(resize_amount)),
                        np.array(result.resize(resize_amount))
                    ], axis=1)
                Image.fromarray(res).save(
                    os.path.join(out_path_coupled, os.path.basename(im_path)))

            im_save_path = os.path.join(out_path_results,
                                        os.path.basename(im_path))
            Image.fromarray(np.array(result)).save(im_save_path)

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
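run_on_batch is defined elsewhere; for plain inference it is presumably little more than a deterministic forward pass. A hedged sketch (the keyword names follow the pSp conventions visible in the style-mixing example above):

import torch

def run_on_batch_sketch(inputs, net, opts):
    # assumed plain-inference path; the real run_on_batch also handles
    # the latent_mask / style-mixing case
    with torch.no_grad():
        return net(inputs, randomize_noise=False, resize=opts.resize_outputs)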
Example #6
def run():
    test_opts = TestOptions().parse()

    out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(root=opts.data_path,
                               transform=transforms_dict['transform_inference'],
                               opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    # get the image corresponding to the latent average
    avg_image = net(net.latent_avg.unsqueeze(0),
                    input_code=True,
                    randomize_noise=False,
                    return_latents=False,
                    average_code=True)[0]
    avg_image = avg_image.to('cuda').float().detach()
    if opts.dataset_type == "cars_encode":
        avg_image = avg_image[:, 32:224, :]
    tensor2im(avg_image).save(os.path.join(opts.exp_dir, 'avg_image.jpg'))

    if opts.dataset_type == "cars_encode":
        resize_amount = (256, 192) if opts.resize_outputs else (512, 384)
    else:
        resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break

        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch, result_latents = run_on_batch(input_cuda, net, opts, avg_image)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(input_batch.shape[0]):
            results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch)]
            im_path = dataset.paths[global_i]

            # save step-by-step results side-by-side
            input_im = tensor2im(input_batch[i])
            res = np.array(results[0].resize(resize_amount))
            for idx, result in enumerate(results[1:]):
                res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1)
            res = np.concatenate(
                [res, np.array(input_im.resize(resize_amount))], axis=1)

            Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path)))

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
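Here result_batch[i][iter_idx] holds one reconstruction per refinement step, and the network is seeded with avg_image. A sketch of the iterative scheme this implies; the keyword names are assumptions, but in ReStyle-style inference the current reconstruction is concatenated to the input and the encoder predicts a latent update each step:

import torch

def iterative_inference_sketch(x, net, avg_image, n_iters):
    # assumed loop: start from the average image, feed the reconstruction
    # back as extra input channels, accumulate latent updates
    y_hat = avg_image.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
    latent, results = None, []
    for _ in range(n_iters):
        x_input = torch.cat([x, y_hat], dim=1)  # 6-channel input
        y_hat, latent = net(x_input, latent=latent, return_latents=True,
                            randomize_noise=False, resize=True)
        results.append(y_hat)
    return results, latent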
Example #7
def run():
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
    os.makedirs(out_path_results, exist_ok=True)

    # load model used for initializing encoder bootstrapping
    ckpt = torch.load(test_opts.model_1_checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts['checkpoint_path'] = test_opts.model_1_checkpoint_path
    opts = Namespace(**opts)
    if opts.encoder_type in ENCODER_TYPES['pSp']:
        net1 = pSp(opts)
    else:
        net1 = e4e(opts)
    net1.eval()
    net1.cuda()

    # load model used for translating input image after initialization
    ckpt = torch.load(test_opts.model_2_checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts['checkpoint_path'] = test_opts.model_2_checkpoint_path
    opts = Namespace(**opts)
    if opts.encoder_type in ENCODER_TYPES['pSp']:
        net2 = pSp(opts)
    else:
        net2 = e4e(opts)
    net2.eval()
    net2.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    # get the image corresponding to the latent average
    avg_image = get_average_image(net1, opts)

    resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size,
                                                            opts.output_size)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = run_on_batch(input_cuda, net1, net2, opts,
                                        avg_image)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(input_batch.shape[0]):
            results = [
                tensor2im(result_batch[i][iter_idx])
                for iter_idx in range(opts.n_iters_per_batch + 1)
            ]
            im_path = dataset.paths[global_i]

            input_im = tensor2im(input_batch[i])

            # save step-by-step results side-by-side
            res = np.array(results[0].resize(resize_amount))
            for idx, result in enumerate(results[1:]):
                res = np.concatenate(
                    [res, np.array(result.resize(resize_amount))], axis=1)
            res = np.concatenate(
                [res, np.array(input_im.resize(resize_amount))], axis=1)
            Image.fromarray(res).save(
                os.path.join(out_path_results, os.path.basename(im_path)))

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
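Encoder bootstrapping uses two models: net1 produces the initial inversion (seeded with its average image) and net2 refines it, which is why n_iters_per_batch + 1 results are read per image above. A sketch under the same assumed conventions as the previous example:

import torch

def bootstrapped_inference_sketch(x, net1, net2, opts, avg_image):
    # assumed flow of run_on_batch: one pass through net1, then
    # opts.n_iters_per_batch refinement passes through net2
    avg = avg_image.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
    y_hat, latent = net1(torch.cat([x, avg], dim=1), return_latents=True,
                         randomize_noise=False, resize=True)
    results = [y_hat]
    for _ in range(opts.n_iters_per_batch):
        y_hat, latent = net2(torch.cat([x, y_hat], dim=1), latent=latent,
                             return_latents=True, randomize_noise=False,
                             resize=True)
        results.append(y_hat)
    return results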
Example #8
def run():
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
    out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    age_transformers = [
        AgeTransformer(target_age=age) for age in opts.target_age.split(',')
    ]

    print(f'Loading dataset for {opts.dataset_type}')
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_time = []
    for age_transformer in age_transformers:
        print(f"Running on target age: {age_transformer.target_age}")
        global_i = 0
        for input_batch in tqdm(dataloader):
            if global_i >= opts.n_images:
                break
            with torch.no_grad():
                input_age_batch = [
                    age_transformer(img.cpu()).to('cuda')
                    for img in input_batch
                ]
                input_age_batch = torch.stack(input_age_batch)
                input_cuda = input_age_batch.cuda().float()
                tic = time.time()
                result_batch = run_on_batch(input_cuda, net, opts)
                toc = time.time()
                global_time.append(toc - tic)

                for i in range(len(input_batch)):
                    result = tensor2im(result_batch[i])
                    im_path = dataset.paths[global_i]

                    # resize_amount is also needed below when saving the result,
                    # so define it outside the couple_outputs branch
                    resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
                    if opts.couple_outputs or global_i % 100 == 0:
                        input_im = log_image(input_batch[i], opts)
                        res = np.concatenate([
                            np.array(input_im.resize(resize_amount)),
                            np.array(result.resize(resize_amount))
                        ], axis=1)
                        age_out_path_coupled = os.path.join(
                            out_path_coupled, age_transformer.target_age)
                        os.makedirs(age_out_path_coupled, exist_ok=True)
                        Image.fromarray(res).save(
                            os.path.join(age_out_path_coupled,
                                         os.path.basename(im_path)))

                    age_out_path_results = os.path.join(
                        out_path_results, age_transformer.target_age)
                    os.makedirs(age_out_path_results, exist_ok=True)
                    image_name = os.path.basename(im_path)
                    im_save_path = os.path.join(age_out_path_results,
                                                image_name)
                    Image.fromarray(np.array(
                        result.resize(resize_amount))).save(im_save_path)
                    global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
Example #9
def run():
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir,
                                    'reference_guided_inference')
    os.makedirs(out_path_results, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    age_transformers = [
        AgeTransformer(target_age=age) for age in opts.target_age.split(',')
    ]

    print(f'Loading dataset for {opts.dataset_type}')
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()

    source_dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    source_dataloader = DataLoader(source_dataset,
                                   batch_size=opts.test_batch_size,
                                   shuffle=False,
                                   num_workers=int(opts.test_workers),
                                   drop_last=False)

    ref_dataset = InferenceDataset(
        paths_list=opts.ref_images_paths_file,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    ref_dataloader = DataLoader(ref_dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=1,
                                drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(source_dataset)

    for age_transformer in age_transformers:
        target_age = age_transformer.target_age
        print(f"Running on target age: {target_age}")
        age_save_path = os.path.join(out_path_results, str(target_age))
        os.makedirs(age_save_path, exist_ok=True)
        global_i = 0
        for i, source_batch in enumerate(tqdm(source_dataloader)):
            if global_i >= opts.n_images:
                break
            results_per_source = {idx: [] for idx in range(len(source_batch))}
            with torch.no_grad():
                for ref_batch in ref_dataloader:
                    source_batch = source_batch.cuda().float()
                    ref_batch = ref_batch.cuda().float()
                    source_input_age_batch = [
                        age_transformer(img.cpu()).to('cuda')
                        for img in source_batch
                    ]
                    source_input_age_batch = torch.stack(
                        source_input_age_batch)

                    # compute w+ of ref images to be injected for style-mixing
                    ref_latents = net.pretrained_encoder(
                        ref_batch) + net.latent_avg

                    # run age transformation on source images with style-mixing
                    res_batch_mixed = run_on_batch(
                        source_input_age_batch,
                        net,
                        opts,
                        latent_to_inject=ref_latents)

                    # store results
                    for idx in range(len(source_batch)):
                        results_per_source[idx].append(
                            [ref_batch[0], res_batch_mixed[idx]])

                # save results
                resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
                for image_idx, image_results in results_per_source.items():
                    input_im_path = source_dataset.paths[global_i]
                    image = source_batch[image_idx]
                    input_image = log_image(image, opts)
                    # initialize results image
                    ref_inputs = np.zeros_like(
                        input_image.resize(resize_amount))
                    mixing_results = np.array(
                        input_image.resize(resize_amount))
                    for ref_idx in range(len(image_results)):
                        ref_input, mixing_result = image_results[ref_idx]
                        ref_input = log_image(ref_input, opts)
                        mixing_result = log_image(mixing_result, opts)
                        # append current results
                        ref_inputs = np.concatenate([
                            ref_inputs,
                            np.array(ref_input.resize(resize_amount))
                        ], axis=1)
                        mixing_results = np.concatenate([
                            mixing_results,
                            np.array(mixing_result.resize(resize_amount))
                        ], axis=1)
                    res = np.concatenate([ref_inputs, mixing_results], axis=0)
                    save_path = os.path.join(age_save_path,
                                             os.path.basename(input_im_path))
                    Image.fromarray(res).save(save_path)
                    global_i += 1
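The line ref_latents = net.pretrained_encoder(ref_batch) + net.latent_avg reflects the convention, used throughout these encoders, that the network predicts an offset from the average latent; the absolute W+ code is the encoder output plus latent_avg. As a one-line sketch:

def encode_to_wplus(encoder, latent_avg, x):
    # encoders here predict offsets from the average latent; adding
    # latent_avg yields the absolute W+ code (as in ref_latents above)
    return encoder(x) + latent_avg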
Example #10
def run():
    """
    This script can be used to perform inversion and editing. Please note that this script supports editing using
    only the ReStyle-e4e model and currently supports editing using three edit directions found using InterFaceGAN
    (age, smile, and pose) on the faces domain.
    For performing the edits please provide the arguments `--edit_directions` and `--factor_ranges`. For example,
    setting these values to be `--edit_directions=age,smile,pose` and `--factor_ranges=5,5,5` will use a lambda range
    between -5 and 5 for each of the attributes. These should be comma-separated lists of the same length. You may
    get better results by playing around with the factor ranges for each edit.
    """
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'editing_results')
    out_path_coupled = os.path.join(test_opts.exp_dir, 'editing_coupled')

    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)
    net = e4e(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    if opts.dataset_type != "ffhq_encode":
        raise ValueError(
            "Editing script only supports edits on the faces domain!")
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    latent_editor = LatentEditor(net.decoder)
    opts.edit_directions = opts.edit_directions.split(',')
    opts.factor_ranges = [
        int(factor) for factor in opts.factor_ranges.split(',')
    ]
    if len(opts.edit_directions) != len(opts.factor_ranges):
        raise ValueError(
            "Invalid edit directions and factor ranges. Please provide a single factor range for each"
            f"edit direction. Given: {opts.edit_directions} and {opts.factor_ranges}"
        )

    avg_image = get_average_image(net, opts)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = edit_batch(input_cuda, net, avg_image,
                                      latent_editor, opts)
            toc = time.time()
            global_time.append(toc - tic)

        resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size,
                                                                opts.output_size)
        for i in range(input_batch.shape[0]):

            im_path = dataset.paths[global_i]
            results = result_batch[i]

            inversion = results.pop('inversion')
            input_im = tensor2im(input_batch[i])

            all_edit_results = []
            for edit_name, edit_res in results.items():
                res = np.array(
                    input_im.resize(resize_amount))  # set the input image
                res = np.concatenate(
                    [res, np.array(inversion.resize(resize_amount))],
                    axis=1)  # set the inversion
                for result in edit_res:
                    res = np.concatenate(
                        [res, np.array(result.resize(resize_amount))], axis=1)
                res_im = Image.fromarray(res)
                all_edit_results.append(res_im)

                edit_save_dir = os.path.join(out_path_results, edit_name)
                os.makedirs(edit_save_dir, exist_ok=True)
                res_im.save(
                    os.path.join(edit_save_dir, os.path.basename(im_path)))

            # save final concatenated result if all factor ranges are equal
            if opts.factor_ranges.count(opts.factor_ranges[0]) == len(
                    opts.factor_ranges):
                coupled_res = np.concatenate(all_edit_results, axis=0)
                im_save_path = os.path.join(out_path_coupled,
                                            os.path.basename(im_path))
                Image.fromarray(coupled_res).save(im_save_path)

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
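edit_batch and LatentEditor are defined elsewhere; an InterFaceGAN edit is a linear walk in latent space along a fixed semantic direction. A minimal sketch of a factor-range sweep consistent with --factor_ranges above (names and step count are hypothetical):

import numpy as np

def interfacegan_sweep_sketch(latent, direction, factor_range, num_steps=7):
    # one edited latent per lambda in [-factor_range, factor_range]
    return [latent + f * direction
            for f in np.linspace(-factor_range, factor_range, num_steps)]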