Example #1
 def parse_and_log_images(self,
                          id_logs,
                          x,
                          y,
                          y_hat,
                          title,
                          subscript=None,
                          display_count=2):
     im_data = []
     for i in range(display_count):
         if isinstance(y_hat, dict):
             output_face = [[
                 common.tensor2im(y_hat[i][iter_idx][0]),
                 y_hat[i][iter_idx][1]
             ] for iter_idx in range(len(y_hat[i]))]
         else:
             output_face = [common.tensor2im(y_hat[i])]
         cur_im_data = {
             'input_face': common.tensor2im(x[i]),
             'target_face': common.tensor2im(y[i]),
             'output_face': output_face,
         }
         if id_logs is not None:
             for key in id_logs[i]:
                 cur_im_data[key] = id_logs[i][key]
         im_data.append(cur_im_data)
     self.log_images(title, im_data=im_data, subscript=subscript)
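Every example on this page funnels its tensors through common.tensor2im (or a bare tensor2im import) to turn a (3, H, W) network output back into a PIL image. For reference, here is a minimal sketch of such a helper, assuming outputs normalized to [-1, 1]; the repository's actual utils/common.py implementation may differ in details:

import numpy as np
import torch
from PIL import Image

def tensor2im_sketch(var: torch.Tensor) -> Image.Image:
    # (3, H, W) -> (H, W, 3), detached from the graph and moved to CPU
    arr = var.detach().cpu().permute(1, 2, 0).numpy()
    # map [-1, 1] -> [0, 1], clamp stray values, then scale to 8-bit
    arr = np.clip((arr + 1) / 2, 0, 1) * 255
    return Image.fromarray(arr.astype('uint8'))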
Example #2
 def parse_and_log_images(self,
                          id_logs,
                          x,
                          y,
                          y_hat,
                          title,
                          subscript=None,
                          display_count=2):
     im_data = []
     normalize_source = any(
         isinstance(transform, transforms.Normalize)
         for transform in self.train_dataset.source_transform.transforms)
     normalize_target = any(
         isinstance(transform, transforms.Normalize)
         for transform in self.train_dataset.target_transform.transforms)
     for i in range(display_count):
         cur_im_data = {
             'input_image':
             common.tensor2im(x[i], normalize=normalize_source),
             'target_image':
             common.tensor2im(y[i], normalize=normalize_target),
             'output_image':
             common.tensor2im(y_hat[i], normalize=normalize_target),
         }
         if id_logs is not None:
             for key in id_logs[i]:
                 cur_im_data[key] = id_logs[i][key]
         im_data.append(cur_im_data)
     self.log_images(title, im_data=im_data, subscript=subscript)
Example #3
def mix_edit():
    f1 = st.sidebar.file_uploader("input image1")
    f2 = st.sidebar.file_uploader("input image2")

    active_layer = st.sidebar.slider("use image2 layer", 0, 17, (12, 17))

    attr_params = setting_attribute_slider()

    cols = st.beta_columns(2)
    if f1 is not None and f2 is not None:
        tfile1 = tempfile.NamedTemporaryFile(delete=False)
        tfile1.write(f1.read())
        image1 = run_alignment(tfile1.name)

        tfile2 = tempfile.NamedTemporaryFile(delete=False)
        tfile2.write(f2.read())
        image2 = run_alignment(tfile2.name)

        if image1 is not None and image2 is not None:
            with torch.no_grad():
                transformed_image1 = img_transforms(image1)
                images1, latents1 = net(
                    transformed_image1.unsqueeze(0).to('cuda').float(),
                    randomize_noise=False,
                    return_latents=True)
                result_image1, latent1 = images1[0], latents1[0]

                transformed_image2 = img_transforms(image2)
                images2, latents2 = net(
                    transformed_image2.unsqueeze(0).to('cuda').float(),
                    randomize_noise=False,
                    return_latents=True)
                result_image2, latent2 = images2[0], latents2[0]
                with cols[0]:
                    st.write("input image")
                    st.image(image1)
                    st.image(image2)
                with cols[1]:
                    st.write("inversion image")
                    st.image(tensor2im(result_image1))
                    st.image(tensor2im(result_image2))
                edit_latent = latent1.clone()  # clone so the slice assignment below does not mutate latent1
                edit_latent[active_layer[0]:active_layer[1]] = latent2[
                    active_layer[0]:active_layer[1]]
                for attr in attr_params.keys():
                    edit_latent += attr_params[attr][0]

                generator = net.decoder
                edit_images, _ = generator([edit_latent.unsqueeze(0)],
                                           input_is_latent=True,
                                           randomize_noise=False,
                                           return_latents=True)

            st.image(tensor2im(edit_images[0]))
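The heart of mix_edit is the W+ layer swap. A standalone sketch of that step, assuming latents of shape (18, 512) as produced by a 1024px StyleGAN2 encoder; cloning first keeps the caller's latent intact:

import torch

def mix_wplus_layers(latent1: torch.Tensor, latent2: torch.Tensor,
                     start: int, end: int) -> torch.Tensor:
    """Copy style layers [start, end) of latent2 into a copy of latent1."""
    mixed = latent1.clone()  # avoid mutating latent1 through the slice assignment
    mixed[start:end] = latent2[start:end]
    return mixed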
Example #4
	def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2):
		im_data = []
		for i in range(display_count):
			cur_im_data = {
				'input_face': common.log_input_image(x[i], self.opts),
				'target_face': common.tensor2im(y[i]),
				'output_face': common.tensor2im(y_hat[i]),
			}
			if id_logs is not None:
				for key in id_logs[i]:
					cur_im_data[key] = id_logs[i][key]
			im_data.append(cur_im_data)
		self.log_images(title, im_data=im_data, subscript=subscript)
Example #5
def get_coupled_results(result_batch, transformed_image):
    """
    Visualize output images from left to right (the input image is on the right)
    """
    result_tensors = result_batch[0]  # there's one image in our batch
    result_images = [
        tensor2im(result_tensors[iter_idx]) for iter_idx in range(NUM_STEPS)
    ]
    input_im = tensor2im(transformed_image)
    res = np.array(result_images[0].resize(resize_amount))
    for idx, result in enumerate(result_images[1:]):
        res = np.concatenate([res, np.array(result.resize(resize_amount))],
                             axis=1)
    res = np.concatenate([res, input_im.resize(resize_amount)], axis=1)
    res = Image.fromarray(res)
    return res, result_images
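The resize-then-concatenate pattern above (np.concatenate along axis=1, then Image.fromarray) recurs throughout these examples; a compact helper expressing the same idea, with names of my own choosing:

import numpy as np
from PIL import Image

def hstack_images(images, size):
    """Resize PIL images to `size` and stitch them left to right."""
    arrays = [np.array(im.resize(size)) for im in images]
    return Image.fromarray(np.concatenate(arrays, axis=1))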
Example #6
 def get_final_output(result_batch, resize_amount,
                      display_intermediate_results, opts):
     result_tensors = result_batch[0]  # there's one image in our batch
     if display_intermediate_results:
         result_images = [
             tensor2im(result_tensors[iter_idx])
             for iter_idx in range(opts.n_iters_per_batch)
         ]
     else:
         result_images = [tensor2im(result_tensors[-1])]
     res = np.array(result_images[0].resize(resize_amount))
     for idx, result in enumerate(result_images[1:]):
         res = np.concatenate(
             [res, np.array(result.resize(resize_amount))], axis=1)
     res = Image.fromarray(res)
     return res
Example #7
 def _latents_to_image(self, all_latents):
     sample_results = {}
     with torch.no_grad():
         for idx, sample_latents in enumerate(all_latents):
             images, _ = self.generator([sample_latents], randomize_noise=False, input_is_latent=True)
             sample_results[idx] = [tensor2im(image) for image in images]
     return sample_results
Example #8
 def log_images_to_wandb(x, y, y_hat, id_logs, prefix, step, opts):
     im_data = []
     column_names = ["Source", "Target", "Output"]
     if id_logs is not None:
         column_names.append("ID Diff Output to Target")
     for i in range(len(x)):
         cur_im_data = [
             wandb.Image(common.log_input_image(x[i], opts)),
             wandb.Image(common.tensor2im(y[i])),
             wandb.Image(common.tensor2im(y_hat[i])),
         ]
         if id_logs is not None:
             cur_im_data.append(id_logs[i]["diff_target"])
         im_data.append(cur_im_data)
     outputs_table = wandb.Table(data=im_data, columns=column_names)
     wandb.log(
         {f"{prefix.title()} Step {step} Output Samples": outputs_table})
Example #9
    def predict(self, model, image):
        opts = self.opts[model]
        opts = Namespace(**opts)
        pprint.pprint(opts)

        net = pSp(opts)
        net.eval()
        net.cuda()
        print('Model successfully loaded!')

        original_image = Image.open(str(image))
        if opts.label_nc == 0:
            original_image = original_image.convert("RGB")
        else:
            original_image = original_image.convert("L")
        original_image = original_image.resize(
            (self.opts[model]['output_size'], self.opts[model]['output_size']))

        # Align Image
        if model not in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
            input_image = self.run_alignment(str(image))
        else:
            input_image = original_image

        img_transforms = self.transforms[model]
        transformed_image = img_transforms(input_image)

        if model in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
            latent_mask = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
        else:
            latent_mask = None

        with torch.no_grad():
            result_image = run_on_batch(transformed_image.unsqueeze(0), net,
                                        latent_mask)[0]
        input_vis_image = log_input_image(transformed_image, opts)
        output_image = tensor2im(result_image)

        if model == "celebs_super_resolution":
            res = np.concatenate([
                np.array(
                    input_vis_image.resize((self.opts[model]['output_size'],
                                            self.opts[model]['output_size']))),
                np.array(
                    output_image.resize((self.opts[model]['output_size'],
                                         self.opts[model]['output_size'])))
            ],
                                 axis=1)
        else:
            res = np.array(
                output_image.resize((self.opts[model]['output_size'],
                                     self.opts[model]['output_size'])))

        out_path = Path(tempfile.mkdtemp()) / "out.png"
        Image.fromarray(np.array(res)).save(str(out_path))
        return out_path
Example #10
 def _latents_to_image(self, latents):
     with torch.no_grad():
         images, _ = self.generator([latents],
                                    randomize_noise=False,
                                    input_is_latent=True)
         if self.is_cars:
             images = images[:, :, 64:448, :]  # 512x512 -> 384x512
     horizontal_concat_image = torch.cat(list(images), 2)
     final_image = tensor2im(horizontal_concat_image)
     return final_image
Example #11
def run(encoder, decoder, latent_avg, original):
    """Encode and decode an image"""

    input_image = tforms(original).to(device)

    with torch.no_grad():

        codes = encoder(input_image.unsqueeze(0).float())
        codes = codes + latent_avg.repeat(codes.shape[0], 1, 1)
        image, latent = decoder([codes], input_is_latent=True)
        out_im = image.squeeze()

    return tensor2im(out_im)
Example #12
def single_edit():
    f = st.sidebar.file_uploader("input image")

    attr_params = setting_attribute_slider()

    cols = st.beta_columns(2)
    if f is not None:
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(f.read())
        image = run_alignment(tfile.name)

        if image is not None:
            with torch.no_grad():
                transformed_image = img_transforms(image)
                images, latents = net(
                    transformed_image.unsqueeze(0).to('cuda').float(),
                    randomize_noise=False,
                    return_latents=True)
                result_image, latent = images[0], latents[0]
                with cols[0]:
                    st.write("input image")
                    st.image(image)
                with cols[1]:
                    st.write("inversion image")
                    st.image(tensor2im(result_image))
                edit_latent = latent.clone()  # clone to avoid editing the inversion latent in place
                for attr in attr_params.keys():
                    edit_latent += attr_params[attr][0]

                generator = net.decoder
                edit_images, _ = generator([edit_latent.unsqueeze(0)],
                                           input_is_latent=True,
                                           randomize_noise=False,
                                           return_latents=True)

            st.image(tensor2im(edit_images[0]))
Example #13
def edit_batch(inputs, net, avg_image, latent_editor, opts):
    y_hat, latents = get_inversions_on_batch(inputs, net, avg_image, opts)
    # store all results for each sample, split by the edit direction
    results = {
        idx: {
            'inversion': tensor2im(y_hat[idx])
        }
        for idx in range(len(inputs))
    }
    for edit_direction, factor_range in zip(opts.edit_directions,
                                            opts.factor_ranges):
        edit_res = latent_editor.apply_interfacegan(
            latents=latents,
            direction=edit_direction,
            factor_range=(-1 * factor_range, factor_range))
        # store the results for each sample
        for idx, sample_res in edit_res.items():
            results[idx][edit_direction] = sample_res
    return results
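latent_editor.apply_interfacegan sweeps a linear InterFaceGAN direction over a symmetric factor range. A minimal sketch of the underlying latent arithmetic, assuming direction broadcasts against the batch of latents (the real LatentEditor also decodes each edited latent back to an image):

import torch

def interfacegan_sweep(latents: torch.Tensor, direction: torch.Tensor,
                       factor_range: tuple, steps: int = 5):
    """Return one edited latent batch per factor: latents + factor * direction."""
    factors = torch.linspace(factor_range[0], factor_range[1], steps)
    return [latents + f * direction for f in factors]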
Example #14
    def predict(self, image, target_age="default"):
        net = pSp(self.opts)
        net.eval()
        if torch.cuda.is_available():
            net.cuda()

        # align image
        aligned_image = run_alignment(str(image))

        input_image = self.transform(aligned_image)

        if target_age == "default":
            target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
            age_transformers = [
                AgeTransformer(target_age=age) for age in target_ages
            ]
        else:
            age_transformers = [AgeTransformer(target_age=target_age)]

        results = np.array(aligned_image.resize((1024, 1024)))
        all_imgs = []
        for age_transformer in age_transformers:
            print(f"Running on target age: {age_transformer.target_age}")
            with torch.no_grad():
                input_image_age = [
                    age_transformer(input_image.cpu()).to("cuda")
                ]
                input_image_age = torch.stack(input_image_age)
                result_tensor = run_on_batch(input_image_age, net)[0]
                result_image = tensor2im(result_tensor)
                all_imgs.append(result_image)
                results = np.concatenate([results, result_image], axis=1)

        if target_age == "default":
            out_path = Path(tempfile.mkdtemp()) / "output.gif"
            imageio.mimwrite(str(out_path), all_imgs, duration=0.3)
        else:
            out_path = Path(tempfile.mkdtemp()) / "output.png"
            imageio.imwrite(str(out_path), all_imgs[0])
        return out_path
Example #15
def save_image(img, save_dir, idx):
    result = tensor2im(img)
    im_save_path = os.path.join(save_dir, f"{idx:06d}.jpg")
    Image.fromarray(np.array(result)).save(im_save_path)
Example #16
def run():
    test_opts = TestOptions().parse()

    if test_opts.resize_factors is not None:
        assert len(
            test_opts.resize_factors.split(',')
        ) == 1, "When running inference, provide a single downsampling factor!"
        out_path_results = os.path.join(
            test_opts.exp_dir, 'inference_results',
            'downsampling_{}'.format(test_opts.resize_factors))
        out_path_coupled = os.path.join(
            test_opts.exp_dir, 'inference_coupled',
            'downsampling_{}'.format(test_opts.resize_factors))
    else:
        out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
        out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')

    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=True)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = run_on_batch(input_cuda, net, opts)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(opts.test_batch_size):
            result = tensor2im(result_batch[i])
            im_path = dataset.paths[global_i]

            if opts.couple_outputs or global_i % 100 == 0:
                input_im = log_input_image(input_batch[i], opts)
                resize_amount = (256, 256) if opts.resize_outputs else (1024,
                                                                        1024)
                if opts.resize_factors is not None:
                    # for super resolution, save the original, down-sampled, and output
                    source = Image.open(im_path)
                    res = np.concatenate([
                        np.array(source.resize(resize_amount)),
                        np.array(
                            input_im.resize(resize_amount,
                                            resample=Image.NEAREST)),
                        np.array(result.resize(resize_amount))
                    ],
                                         axis=1)
                else:
                    # otherwise, save the original and output
                    res = np.concatenate([
                        np.array(input_im.resize(resize_amount)),
                        np.array(result.resize(resize_amount))
                    ],
                                         axis=1)
                Image.fromarray(res).save(
                    os.path.join(out_path_coupled, os.path.basename(im_path)))

            im_save_path = os.path.join(out_path_results,
                                        os.path.basename(im_path))
            Image.fromarray(np.array(result)).save(im_save_path)

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
Example #17
def run():
    """
    This script can be used to perform inversion and editing. Please note that this script supports editing using
    only the ReStyle-e4e model and currently supports editing using three edit directions found using InterFaceGAN
    (age, smile, and pose) on the faces domain.
    For performing the edits please provide the arguments `--edit_directions` and `--factor_ranges`. For example,
    setting these values to be `--edit_directions=age,smile,pose` and `--factor_ranges=5,5,5` will use a lambda range
    between -5 and 5 for each of the attributes. These should be comma-separated lists of the same length. You may
    get better results by playing around with the factor ranges for each edit.
    """
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'editing_results')
    out_path_coupled = os.path.join(test_opts.exp_dir, 'editing_coupled')

    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)
    net = e4e(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    if opts.dataset_type != "ffhq_encode":
        raise ValueError(
            "Editing script only supports edits on the faces domain!")
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    latent_editor = LatentEditor(net.decoder)
    opts.edit_directions = opts.edit_directions.split(',')
    opts.factor_ranges = [
        int(factor) for factor in opts.factor_ranges.split(',')
    ]
    if len(opts.edit_directions) != len(opts.factor_ranges):
        raise ValueError(
            "Invalid edit directions and factor ranges. Please provide a single factor range for each"
            f"edit direction. Given: {opts.edit_directions} and {opts.factor_ranges}"
        )

    avg_image = get_average_image(net, opts)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = edit_batch(input_cuda, net, avg_image,
                                      latent_editor, opts)
            toc = time.time()
            global_time.append(toc - tic)

        resize_amount = (256,
                         256) if opts.resize_outputs else (opts.output_size,
                                                           opts.output_size)
        for i in range(input_batch.shape[0]):

            im_path = dataset.paths[global_i]
            results = result_batch[i]

            inversion = results.pop('inversion')
            input_im = tensor2im(input_batch[i])

            all_edit_results = []
            for edit_name, edit_res in results.items():
                res = np.array(
                    input_im.resize(resize_amount))  # set the input image
                res = np.concatenate(
                    [res, np.array(inversion.resize(resize_amount))],
                    axis=1)  # set the inversion
                for result in edit_res:
                    res = np.concatenate(
                        [res, np.array(result.resize(resize_amount))], axis=1)
                res_im = Image.fromarray(res)
                all_edit_results.append(res_im)

                edit_save_dir = os.path.join(out_path_results, edit_name)
                os.makedirs(edit_save_dir, exist_ok=True)
                res_im.save(
                    os.path.join(edit_save_dir, os.path.basename(im_path)))

            # save final concatenated result if all factor ranges are equal
            if opts.factor_ranges.count(opts.factor_ranges[0]) == len(
                    opts.factor_ranges):
                coupled_res = np.concatenate(all_edit_results, axis=0)
                im_save_path = os.path.join(out_path_coupled,
                                            os.path.basename(im_path))
                Image.fromarray(coupled_res).save(im_save_path)

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
Example #18
def run():
	test_opts = TestOptions().parse()

	out_path_results = os.path.join(test_opts.exp_dir, 'inference_side_by_side')
	os.makedirs(out_path_results, exist_ok=True)

	# update test options with options used during training
	ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
	opts = ckpt['opts']
	opts.update(vars(test_opts))
	opts = Namespace(**opts)

	net = pSp(opts)
	net.eval()
	net.cuda()

	age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]

	print(f'Loading dataset for {opts.dataset_type}')
	dataset_args = data_configs.DATASETS[opts.dataset_type]
	transforms_dict = dataset_args['transforms'](opts).get_transforms()
	dataset = InferenceDataset(root=opts.data_path,
							   transform=transforms_dict['transform_inference'],
							   opts=opts,
							   return_path=True)
	dataloader = DataLoader(dataset,
							batch_size=opts.test_batch_size,
							shuffle=False,
							num_workers=int(opts.test_workers),
							drop_last=False)

	if opts.n_images is None:
		opts.n_images = len(dataset)

	global_time = []
	global_i = 0
	for input_batch, image_paths in tqdm(dataloader):
		if global_i >= opts.n_images:
			break
		batch_results = {}
		for idx, age_transformer in enumerate(age_transformers):
			with torch.no_grad():
				input_age_batch = [age_transformer(img.cpu()).to('cuda') for img in input_batch]
				input_age_batch = torch.stack(input_age_batch)
				input_cuda = input_age_batch.cuda().float()
				tic = time.time()
				result_batch = run_on_batch(input_cuda, net, opts)
				toc = time.time()
				global_time.append(toc - tic)

				resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
				for i in range(len(input_batch)):
					result = tensor2im(result_batch[i])
					im_path = image_paths[i]
					input_im = log_image(input_batch[i], opts)
					if im_path not in batch_results.keys():
						batch_results[im_path] = np.array(input_im.resize(resize_amount))
					batch_results[im_path] = np.concatenate([batch_results[im_path],
															 np.array(result.resize(resize_amount))],
															axis=1)

		for im_path, res in batch_results.items():
			image_name = os.path.basename(im_path)
			im_save_path = os.path.join(out_path_results, image_name)
			Image.fromarray(np.array(res)).save(im_save_path)
			global_i += 1
Example #19
def run():
    # ===============================
    # Define the used variables
    # ===============================
    IMAGE_DISPLAY_SIZE = (512, 512)
    MODEL_INFERENCE_SIZE = (256, 256)
    IMAGE_DIR = 'demo_photo'
    TEAM_DIR = 'team'
    MODEL_CONFIG = "./configs/demo_site.yaml"

    # ==============
    # Set up model
    # ==============
    # Load the model args
    with open(MODEL_CONFIG, "r") as fp:
        opts = yaml.load(fp, Loader=yaml.FullLoader)
        opts = AttrDict(opts)

    net = model_init(opts)

    # Set up the transformer for input image
    inference_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])

    # ===============================
    # Construct demo site by streamlit
    # ===============================
    with st.sidebar:
        st.header("NTUST EdgeAI Course")
        with st.form(key="grid_reset"):
            n_photos = st.slider("Number of photos to generate:", 2, 16, 8)
            n_cols = st.number_input("Number of columns", 2, 8, 4)
            mixed_alpha = st.slider("Style mixing strength (0~100)", 0, 100, 50)
            latent_code_range = st.slider('Select a range of values', 1, 18,
                                          (9, 18))
            st.form_submit_button(label="Generate new images")

    st.title('Welcome to mvclab Sketch2Real')
    st.write(" ------ ")

    # For demo columns

    # Load demo images
    demo_edge = Image.open('./demo_image/fakeface_edge.jpg')
    demo_prediction = Image.open('./demo_image/fakeface_prediction.jpg')

    # Create demo columns on website
    demo_left_column, demo_right_column = st.beta_columns(2)
    demo_left_column.image(demo_edge, caption="Demo edge image")
    demo_right_column.image(demo_prediction, caption="Demo generated image")

    # Create a img upload button
    uploaded_file = st.file_uploader("Choose an image...",
                                     type=["jpg", "png", "jpeg"])
    left_column, right_column = st.beta_columns(2)
    if uploaded_file is not None:
        input_image = Image.open(uploaded_file)
        left_column.image(input_image.resize(IMAGE_DISPLAY_SIZE,
                                             Image.ANTIALIAS),
                          caption="Your portrait image (edge only)")
        tensor_img = inference_transform(input_image).cuda().float().unsqueeze(
            0)
        result = run_on_batch(tensor_img, net, opts)
        result = tensor2im(result[0]).resize(IMAGE_DISPLAY_SIZE,
                                             Image.ANTIALIAS)
        right_column.image(result, caption="Generated image")
        # Create grid
        n_rows = (n_photos + n_cols - 1) // n_cols  # enough rows for all photos
        rows = [st.beta_container() for _ in range(n_rows)]
        cols_per_row = [r.beta_columns(n_cols) for r in rows]

        start_time = time.time()
        for image_index in range(n_photos):
            with rows[image_index // n_cols]:
                # Generate with alpha and latent code attributes.
                result = run_on_batch(tensor_img,
                                      net,
                                      opts,
                                      latent_code=latent_code_range,
                                      mixed_alpha=mixed_alpha / 100)
                result = tensor2im(result[0]).resize(IMAGE_DISPLAY_SIZE,
                                                     Image.ANTIALIAS)
                cols_per_row[image_index // n_cols][image_index %
                                                    n_cols].image(result)
        spend_time = time.time() - start_time
        print(
            f"Total time = {spend_time:.2f}s, average per image = {spend_time / n_photos:.2f}s"
        )
Example #20
def online_inference_fn(image_path, output_path):
    # Step 1: load the image and resize it
    # Visual input
    # image_path = EXPERIMENT_DATA_ARGS[experiment_type]["image_path"]
    file_name = image_path.split("/")[-1]
    original_image = Image.open(image_path)
    if opts.label_nc == 0:
        original_image = original_image.convert("RGB")
    else:
        original_image = original_image.convert("L")
    original_image = original_image.resize((256, 256), resample=Image.BILINEAR)

    # Step 2: crop and align the face (run_alignment)
    if experiment_type not in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
        try:
            input_image = run_alignment(image_path)
        except Exception:
            print("run_alignment failed; falling back to the original image")
            input_image = original_image
    else:
        input_image = original_image

    # Step 3: run the generator (perform inference)

    img_transforms = EXPERIMENT_ARGS['transform']
    transformed_image = img_transforms(input_image)

    if experiment_type in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
        latent_mask = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    else:
        latent_mask = None
    with torch.no_grad():
        tic = time.time()
        result_image = run_on_batch(transformed_image.unsqueeze(0), net,
                                    latent_mask)[0]
        toc = time.time()
        print('Inference took {:.4f} seconds.'.format(toc - tic))

    # Step 4: visualize the result
    input_vis_image = log_input_image(transformed_image, opts)
    output_image = tensor2im(result_image)
    if experiment_type == "celebs_super_resolution":
        res = np.concatenate([
            np.array(input_image.resize((256, 256))),
            np.array(input_vis_image.resize((256, 256))),
            np.array(output_image.resize((256, 256)))
        ],
                             axis=1)
    else:
        res = np.concatenate([
            np.array(input_vis_image.resize((256, 256))),
            np.array(output_image.resize((256, 256)))
        ],
                             axis=1)
    res_image = Image.fromarray(res)  # convert the array back to a PIL Image

    # save the coupled image to the output path
    res_image.save(output_path + "_" + file_name + "_toon.jpg")
Example #21
    def __init__(self, opts):
        self.opts = opts

        self.global_step = 0

        self.device = 'cuda:0'  # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
        self.opts.device = self.device

        # Initialize network
        self.net = pSp(self.opts).to(self.device)

        # get the image corresponding to the latent average
        self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
                                  input_code=True,
                                  randomize_noise=False,
                                  return_latents=False,
                                  average_code=True)[0]
        self.avg_image = self.avg_image.to(self.device).float().detach()
        if self.opts.dataset_type == "cars_encode":
            self.avg_image = self.avg_image[:, 32:224, :]
        common.tensor2im(self.avg_image).save(
            os.path.join(self.opts.exp_dir, 'avg_image.jpg'))

        # Initialize loss
        if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
            raise ValueError(
                'Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!'
            )
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.w_norm_lambda > 0:
            self.w_norm_loss = w_norm.WNormLoss(
                start_from_latent_avg=self.opts.start_from_latent_avg)
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize dataset
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(
                                              self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps
Example #22
def run():
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
    os.makedirs(out_path_results, exist_ok=True)

    # load model used for initializing encoder bootstrapping
    ckpt = torch.load(test_opts.model_1_checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts['checkpoint_path'] = test_opts.model_1_checkpoint_path
    opts = Namespace(**opts)
    if opts.encoder_type in ENCODER_TYPES['pSp']:
        net1 = pSp(opts)
    else:
        net1 = e4e(opts)
    net1.eval()
    net1.cuda()

    # load model used for translating input image after initialization
    ckpt = torch.load(test_opts.model_2_checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts['checkpoint_path'] = test_opts.model_2_checkpoint_path
    opts = Namespace(**opts)
    if opts.encoder_type in ENCODER_TYPES['pSp']:
        net2 = pSp(opts)
    else:
        net2 = e4e(opts)
    net2.eval()
    net2.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    # get the image corresponding to the latent average
    avg_image = get_average_image(net1, opts)

    resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size,
                                                            opts.output_size)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = run_on_batch(input_cuda, net1, net2, opts,
                                        avg_image)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(input_batch.shape[0]):
            results = [
                tensor2im(result_batch[i][iter_idx])
                for iter_idx in range(opts.n_iters_per_batch + 1)
            ]
            im_path = dataset.paths[global_i]

            input_im = tensor2im(input_batch[i])

            # save step-by-step results side-by-side
            res = np.array(results[0].resize(resize_amount))
            for idx, result in enumerate(results[1:]):
                res = np.concatenate(
                    [res, np.array(result.resize(resize_amount))], axis=1)
            res = np.concatenate([res, input_im.resize(resize_amount)], axis=1)
            Image.fromarray(res).save(
                os.path.join(out_path_results, os.path.basename(im_path)))

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
Example #23
def psp(image_path=None, net=None, opts=None):
    from argparse import Namespace
    import time
    import os
    import sys
    import pprint
    import numpy as np
    from PIL import Image
    import torch
    import torchvision.transforms as transforms

    sys.path.append(".")
    sys.path.append("..")

    from datasets import augmentations
    from utils.common import tensor2im, log_input_image
    from models.psp import pSp

    CODE_DIR = 'pixel2style2pixel'

    experiment_type = 'ffhq_encode'

    def get_download_model_command(file_id, file_name):
        """ Get wget download command for downloading the desired model and save to directory ../pretrained_models. """
        current_directory = os.getcwd()
        save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR,
                                 "pretrained_models")
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(
            FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
        return url

    MODEL_PATHS = {
        "ffhq_encode": {
            "id": "1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0",
            "name": "psp_ffhq_encode.pt"
        },
        "ffhq_frontalize": {
            "id": "1_S4THAzXb-97DbpXmanjHtXRyKxqjARv",
            "name": "psp_ffhq_frontalization.pt"
        },
        "celebs_sketch_to_face": {
            "id": "1lB7wk7MwtdxL-LL4Z_T76DuCfk00aSXA",
            "name": "psp_celebs_sketch_to_face.pt"
        },
        "celebs_seg_to_face": {
            "id": "1VpEKc6E6yG3xhYuZ0cq8D2_1CbT0Dstz",
            "name": "psp_celebs_seg_to_face.pt"
        },
        "celebs_super_resolution": {
            "id": "1ZpmSXBpJ9pFEov6-jjQstAlfYbkebECu",
            "name": "psp_celebs_super_resolution.pt"
        },
        "toonify": {
            "id": "1YKoiVuFaqdvzDP5CZaqa3k5phL-VDmyz",
            "name": "psp_ffhq_toonify.pt"
        }
    }

    path = MODEL_PATHS[experiment_type]
    download_command = get_download_model_command(file_id=path["id"],
                                                  file_name=path["name"])

    EXPERIMENT_DATA_ARGS = {
        "ffhq_encode": {
            "model_path":
            "pretrained_models/psp_ffhq_encode.pt",
            "image_path":
            "/home/dibabdal/Desktop/MySpace/Devs/SfSNet-Pytorch-master/Images/REF_ID00353_Cam11_0063.png",
            "transform":
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        },
        "ffhq_frontalize": {
            "model_path":
            "pretrained_models/psp_ffhq_frontalization.pt",
            "image_path":
            "notebooks/images/input_img.jpg",
            "transform":
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        },
        "celebs_sketch_to_face": {
            "model_path":
            "pretrained_models/psp_celebs_sketch_to_face.pt",
            "image_path":
            "notebooks/images/input_sketch.jpg",
            "transform":
            transforms.Compose(
                [transforms.Resize((256, 256)),
                 transforms.ToTensor()])
        },
        "celebs_seg_to_face": {
            "model_path":
            "pretrained_models/psp_celebs_seg_to_face.pt",
            "image_path":
            "notebooks/images/input_mask.png",
            "transform":
            transforms.Compose([
                transforms.Resize((256, 256)),
                augmentations.ToOneHot(n_classes=19),
                transforms.ToTensor()
            ])
        },
        "celebs_super_resolution": {
            "model_path":
            "pretrained_models/psp_celebs_super_resolution.pt",
            "image_path":
            "notebooks/images/input_img.jpg",
            "transform":
            transforms.Compose([
                transforms.Resize((256, 256)),
                augmentations.BilinearResize(factors=[16]),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        },
        "toonify": {
            "model_path":
            "pretrained_models/psp_ffhq_toonify.pt",
            "image_path":
            "notebooks/images/input_img.jpg",
            "transform":
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        },
    }

    EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]

    if net is None:

        model_path = EXPERIMENT_ARGS['model_path']
        ckpt = torch.load(model_path, map_location='cpu')

        opts = ckpt['opts']
        # update the training options
        opts['checkpoint_path'] = model_path
        if 'learn_in_w' not in opts:
            opts['learn_in_w'] = False

        opts = Namespace(**opts)
        net = pSp(opts)
        net.eval()
        net.cuda()
        print('Model successfully loaded!')

    if image_path is None:
        image_path = EXPERIMENT_DATA_ARGS[experiment_type]["image_path"]

    original_image = Image.open(image_path)
    if opts.label_nc == 0:
        original_image = original_image.convert("RGB")
    else:
        original_image = original_image.convert("L")

    original_image = original_image.resize((256, 256))

    def run_alignment(image_path):
        import dlib
        from scripts.align_all_parallel import align_face
        predictor = dlib.shape_predictor(
            "shape_predictor_68_face_landmarks.dat")
        aligned_image = align_face(filepath=image_path, predictor=predictor)
        print("Aligned image has shape: {}".format(aligned_image.size))
        return aligned_image

    if experiment_type not in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
        input_image = run_alignment(image_path)
    else:
        input_image = original_image

    input_image = input_image.resize((256, 256))
    img_transforms = EXPERIMENT_ARGS['transform']
    transformed_image = img_transforms(input_image)

    def run_on_batch(inputs, net, latent_mask=None):
        latent = None
        if latent_mask is None:
            result_batch, latent = net(inputs.to("cuda").float(),
                                       randomize_noise=False,
                                       return_latents=True)

        else:
            result_batch = []
            for image_idx, input_image in enumerate(inputs):
                # get latent vector to inject into our input image
                vec_to_inject = np.random.randn(1, 512).astype('float32')
                _, latent_to_inject = net(
                    torch.from_numpy(vec_to_inject).to("cuda"),
                    input_code=True,
                    return_latents=True)

                # get output image with injected style vector
                res = net(input_image.unsqueeze(0).to("cuda").float(),
                          latent_mask=latent_mask,
                          inject_latent=latent_to_inject)
                result_batch.append(res)
            result_batch = torch.cat(result_batch, dim=0)
        return result_batch, latent

    if experiment_type in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
        latent_mask = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    else:
        latent_mask = None

    with torch.no_grad():
        tic = time.time()
        result_image, latent = run_on_batch(transformed_image.unsqueeze(0),
                                            net, latent_mask)
        result_image = result_image[0]
        toc = time.time()
        print('Inference took {:.4f} seconds.'.format(toc - tic))

    input_vis_image = log_input_image(transformed_image, opts)
    output_image = tensor2im(result_image)

    if experiment_type == "celebs_super_resolution":
        res = np.concatenate([
            np.array(input_image.resize((256, 256))),
            np.array(input_vis_image.resize((256, 256))),
            np.array(output_image.resize((256, 256)))
        ],
                             axis=1)
    else:
        res = np.concatenate([
            np.array(input_vis_image.resize((256, 256))),
            np.array(output_image.resize((256, 256)))
        ],
                             axis=1)

    res_image = Image.fromarray(res)
    import gc
    gc.collect()
    return res_image, latent, net, opts
Example #24
def run():
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
    out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    age_transformers = [
        AgeTransformer(target_age=age) for age in opts.target_age.split(',')
    ]

    print(f'Loading dataset for {opts.dataset_type}')
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_time = []
    resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
    for age_transformer in age_transformers:
        print(f"Running on target age: {age_transformer.target_age}")
        global_i = 0
        for input_batch in tqdm(dataloader):
            if global_i >= opts.n_images:
                break
            with torch.no_grad():
                input_age_batch = [
                    age_transformer(img.cpu()).to('cuda')
                    for img in input_batch
                ]
                input_age_batch = torch.stack(input_age_batch)
                input_cuda = input_age_batch.cuda().float()
                tic = time.time()
                result_batch = run_on_batch(input_cuda, net, opts)
                toc = time.time()
                global_time.append(toc - tic)

                for i in range(len(input_batch)):
                    result = tensor2im(result_batch[i])
                    im_path = dataset.paths[global_i]

                    if opts.couple_outputs or global_i % 100 == 0:
                        input_im = log_image(input_batch[i], opts)
                        res = np.concatenate([
                            np.array(input_im.resize(resize_amount)),
                            np.array(result.resize(resize_amount))
                        ],
                                             axis=1)
                        age_out_path_coupled = os.path.join(
                            out_path_coupled, age_transformer.target_age)
                        os.makedirs(age_out_path_coupled, exist_ok=True)
                        Image.fromarray(res).save(
                            os.path.join(age_out_path_coupled,
                                         os.path.basename(im_path)))

                    age_out_path_results = os.path.join(
                        out_path_results, age_transformer.target_age)
                    os.makedirs(age_out_path_results, exist_ok=True)
                    image_name = os.path.basename(im_path)
                    im_save_path = os.path.join(age_out_path_results,
                                                image_name)
                    Image.fromarray(np.array(
                        result.resize(resize_amount))).save(im_save_path)
                    global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time),
                                                 np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
Example #25
def run():
    test_opts = TestOptions().parse()

    if test_opts.resize_factors is not None:
        factors = test_opts.resize_factors.split(',')
        assert len(
            factors
        ) == 1, "When running inference, please provide a single downsampling factor!"
        mixed_path_results = os.path.join(
            test_opts.exp_dir, 'style_mixing',
            'downsampling_{}'.format(test_opts.resize_factors))
    else:
        mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing')
    os.makedirs(mixed_path_results, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(
        root=opts.data_path,
        transform=transforms_dict['transform_inference'],
        opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=True)

    latent_mask = [int(l) for l in opts.latent_mask.split(",")]
    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_batch = input_batch.cuda()
            for image_idx, input_image in enumerate(input_batch):
                # generate random vectors to inject into input image
                vecs_to_inject = np.random.randn(opts.n_outputs_to_generate,
                                                 512).astype('float32')
                multi_modal_outputs = []
                for vec_to_inject in vecs_to_inject:
                    cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to(
                        "cuda")
                    # get latent vector to inject into our input image
                    _, latent_to_inject = net(cur_vec,
                                              input_code=True,
                                              return_latents=True)
                    # get output image with injected style vector
                    res = net(input_image.unsqueeze(0).to("cuda").float(),
                              latent_mask=latent_mask,
                              inject_latent=latent_to_inject,
                              alpha=opts.mix_alpha)
                    multi_modal_outputs.append(res[0])

                # visualize multi modal outputs
                input_im_path = dataset.paths[global_i]
                image = input_batch[image_idx]
                input_image = log_input_image(image, opts)
                res = np.array(input_image.resize((256, 256)))
                for output in multi_modal_outputs:
                    output = tensor2im(output)
                    res = np.concatenate(
                        [res, np.array(output.resize((256, 256)))], axis=1)
                Image.fromarray(res).save(
                    os.path.join(mixed_path_results,
                                 os.path.basename(input_im_path)))
                global_i += 1
Example #26
def run():
    test_opts = TestOptions().parse()

    out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
    os.makedirs(out_path_coupled, exist_ok=True)

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(root=opts.data_path,
                               transform=transforms_dict['transform_inference'],
                               opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    # get the image corresponding to the latent average
    avg_image = net(net.latent_avg.unsqueeze(0),
                    input_code=True,
                    randomize_noise=False,
                    return_latents=False,
                    average_code=True)[0]
    avg_image = avg_image.to('cuda').float().detach()
    if opts.dataset_type == "cars_encode":
        avg_image = avg_image[:, 32:224, :]
    tensor2im(avg_image).save(os.path.join(opts.exp_dir, 'avg_image.jpg'))

    if opts.dataset_type == "cars_encode":
        resize_amount = (256, 192) if opts.resize_outputs else (512, 384)
    else:
        resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break

        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch, result_latents = run_on_batch(input_cuda, net, opts, avg_image)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(input_batch.shape[0]):
            results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch)]
            im_path = dataset.paths[global_i]

            # save step-by-step results side-by-side
            input_im = tensor2im(input_batch[i])
            res = np.array(results[0].resize(resize_amount))
            for idx, result in enumerate(results[1:]):
                res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1)
            res = np.concatenate([res, input_im.resize(resize_amount)], axis=1)

            Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path)))

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
Example #27
    def __init__(self, opts, prev_train_checkpoint=None):
        self.opts = opts

        self.global_step = 0

        self.device = 'cuda:0'
        self.opts.device = self.device

        # Initialize network
        self.net = e4e(self.opts).to(self.device)

        # Estimate latent_avg via dense sampling if latent_avg is not available
        if self.net.latent_avg is None:
            self.net.latent_avg = self.net.decoder.mean_latent(
                int(1e5))[0].detach()

        # get the image corresponding to the latent average
        self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
                                  input_code=True,
                                  randomize_noise=False,
                                  return_latents=False,
                                  average_code=True)[0]
        self.avg_image = self.avg_image.to(self.device).float().detach()
        if self.opts.dataset_type == "cars_encode":
            self.avg_image = self.avg_image[:, 32:224, :]
        common.tensor2im(self.avg_image).save(
            os.path.join(self.opts.exp_dir, 'avg_image.jpg'))

        # Initialize loss
        if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
            raise ValueError(
                'Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!'
            )
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize discriminator
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator = LatentCodesDiscriminator(512,
                                                          4).to(self.device)
            self.discriminator_optimizer = torch.optim.Adam(
                list(self.discriminator.parameters()),
                lr=opts.w_discriminator_lr)
            self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
            self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)

        # Initialize dataset
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(
                                              self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

        if prev_train_checkpoint is not None:
            self.load_from_train_checkpoint(prev_train_checkpoint)
            prev_train_checkpoint = None