Example #1
Extended generator variant: supports an optional seed, psi scaling of the initial latent, initial values for the optimized variables, and a step post-processing hook; intermediate and final results are yielded as dicts.
    def forward(
        self,
        ref_im,
        loss_str,
        eps,
        noise_type,
        num_trainable_noise_layers,
        tile_latent,
        bad_noise_layers,
        opt_name,
        learning_rate,
        steps,
        lr_schedule,
        save_intermediate,
        seed=0,
        var_list_initial_values=None,
        step_postprocess=None,
        psi=1.0,
        **kwargs,
    ):

        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if tile_latent:
            latent = torch.randn(
                (batch_size, 1, 512),
                dtype=torch.float,
                requires_grad=True,
                device="cuda",
            )
        else:
            latent = torch.randn(
                (batch_size, 18, 512),
                dtype=torch.float,
                requires_grad=True,
                device="cuda",
            )
        with torch.no_grad():
            latent *= psi

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        # Build the 18 per-resolution noise inputs: layers listed in bad_noise_layers
        # are always zeroed; otherwise noise_type selects zero, fixed, or trainable
        # noise (only the first num_trainable_noise_layers are actually optimized)
        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

            if noise_type == "zero" or i in [
                    int(layer) for layer in bad_noise_layers.split(".")
            ]:
                new_noise = torch.zeros(res, dtype=torch.float, device="cuda")
                new_noise.requires_grad = False
            elif noise_type == "fixed":
                new_noise = torch.randn(res, dtype=torch.float, device="cuda")
                new_noise.requires_grad = False
            elif noise_type == "trainable":
                new_noise = torch.randn(res, dtype=torch.float, device="cuda")
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)

        # Variables to optimize jointly: the latent code plus any trainable noise layers
        var_list = [latent] + noise_vars

        if var_list_initial_values is not None:
            assert len(var_list) == len(var_list_initial_values)
            with torch.no_grad():
                for var, initial_value in zip(var_list,
                                              var_list_initial_values):
                    var.copy_(initial_value)

        # Optimizer constructors selectable by name; "custom" is AdamW with betas=(0.9, 0.99)
        opt_dict = {
            "sgd": torch.optim.SGD,
            "adam": torch.optim.Adam,
            "sgdm": partial(torch.optim.SGD, momentum=0.9),
            "adamax": torch.optim.Adamax,
            "custom": partial(torch.optim.AdamW, betas=(0.9, 0.99)),
        }
        opt_func = opt_dict[opt_name]
        if step_postprocess is None:
            opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)
        else:
            opt = StepPostProcessOptimizer(opt_func,
                                           var_list,
                                           step_postprocess=step_postprocess,
                                           lr=learning_rate)

        # LR multipliers for LambdaLR: constant, a triangular ramp (0.1x -> 1x -> 0.1x),
        # or the same ramp over the first 90% of steps followed by a linear drop to 0.001x
        schedule_dict = {
            "fixed":
            lambda x: 1,
            "linear1cycle":
            lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            "linear1cycledrop":
            lambda x: (9 * (1 - np.abs(x /
                                       (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) /
            (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] +
                                   self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict["TOTAL"] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f"BEST ({j+1}) | " + " | ".join(
                    [f"{x}: {y:.4f}" for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict["L2"]

            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield dict(
                    final=False,
                    min_l2=min_l2,
                    HR=best_im.cpu().detach().clamp(0, 1),
                    LR=loss_builder.D(best_im).cpu().detach().clamp(0, 1),
                )

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f" | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}"
        if self.verbose:
            print(best_summary + current_info)
        yield dict(
            final=True,
            min_l2=min_l2,
            success=min_l2 <= eps,
            HR=gen_im.clone().cpu().detach().clamp(0, 1),
            LR=loss_builder.D(best_im).cpu().detach().clamp(0, 1),
            var_list=var_list,
            loss_dict=loss_dict,
        )
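
Example #1's forward is a generator: with save_intermediate set it yields dicts with final=False during optimization, and it always finishes with one final=True dict carrying the images, the optimized variables, and the loss breakdown. Below is a minimal consumption sketch, assuming a PULSE-like wrapper module named model whose forward is the method above; model, ref_im, and all argument values are illustrative assumptions rather than values taken from the listing.

# Hypothetical driver; model and ref_im (a batch of reference images already on
# the GPU) are assumed to exist, and every argument value below is illustrative.
gen = model(
    ref_im,
    loss_str="100*L2",          # the loop reads loss_dict["L2"], so include an L2 term
    eps=2e-3,
    noise_type="trainable",
    num_trainable_noise_layers=5,
    tile_latent=False,
    bad_noise_layers="17",
    opt_name="adam",
    learning_rate=0.4,
    steps=100,
    lr_schedule="linear1cycledrop",
    save_intermediate=True,
)

final = None
for out in gen:
    if out["final"]:
        final = out  # contains HR, LR, var_list, loss_dict, min_l2, success
    else:
        hr, lr = out["HR"], out["LR"]  # best images found so far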
Example #2
Baseline variant: the same optimization loop, but it yields (HR, LR) image tuples and only yields a final pair if the best L2 loss ends up within eps.
    def forward(self, ref_im, seed, loss_str, eps, noise_type,
                num_trainable_noise_layers, tile_latent, bad_noise_layers,
                opt_name, learning_rate, steps, lr_schedule, save_intermediate,
                **kwargs):

        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor
        if tile_latent:
            latent = torch.randn((batch_size, 1, 512),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')
        else:
            latent = torch.randn((batch_size, 18, 512),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor
            res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

            if noise_type == 'zero' or i in [
                    int(layer) for layer in bad_noise_layers.split('.')
            ]:
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise Exception("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed':
            lambda x: 1,
            'linear1cycle':
            lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            'linear1cycledrop':
            lambda x: (9 * (1 - np.abs(x /
                                       (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) /
            (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] +
                                   self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict['L2']

            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield (best_im.cpu().detach().clamp(0, 1),
                       loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
        if self.verbose:
            print(best_summary + current_info)
        if min_l2 <= eps:
            yield (gen_im.clone().cpu().detach().clamp(0, 1),
                   loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print(
                "Could not find a face that downscales correctly within epsilon"
            )
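
All three examples delegate the update to SphericalOptimizer, which is not shown in the listings. From the way it is used (opt.opt is the wrapped torch optimizer handed to LambdaLR, and opt.step() performs the update) and from its name, it presumably takes a normal optimizer step and then re-normalizes each variable back to a fixed radius, keeping the latent near the sphere where StyleGAN-style inputs concentrate. The class below is a rough sketch written from that description, not the actual implementation; only the constructor signature, the .opt attribute, and step() mirror the usage above.

import torch


class SphericalOptimizerSketch:
    """Hypothetical stand-in: after every inner optimizer step, rescale each
    variable back to its initial norm over its trailing dimensions."""

    def __init__(self, optimizer_class, params, **kwargs):
        self.opt = optimizer_class(params, **kwargs)  # exposed as .opt, as used above
        self.params = params
        with torch.no_grad():
            self.radii = {
                p: (p.pow(2).sum(tuple(range(2, p.ndim)), keepdim=True) + 1e-9).sqrt()
                for p in params
            }

    @torch.no_grad()
    def step(self, closure=None):
        loss = self.opt.step(closure)
        for p in self.params:
            # project back onto the sphere with the stored radius
            p.div_((p.pow(2).sum(tuple(range(2, p.ndim)), keepdim=True) + 1e-9).sqrt())
            p.mul_(self.radii[p])
        return loss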
Example #3
Simplified variant: only the latent is optimized (no per-layer noise tensors), the Gaussian fit is computed inside forward, and synthesis is called with the latent alone.
    def forward(self, ref_im, loss_str, eps, tile_latent, opt_name, steps,
                learning_rate, lr_schedule, save_intermediate):

        gaussian_fit = self.Gaussian_fit()

        batch_size = ref_im.shape[0]

        if tile_latent:
            latent = torch.randn((batch_size, 1, 512),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')
        else:
            latent = torch.randn((batch_size, 18, 512),
                                 dtype=torch.float,
                                 requires_grad=True,
                                 device='cuda')

        var_list = [latent]

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed':
            lambda x: 1,
            'linear1cycle':
            lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            'linear1cycledrop':
            lambda x: (9 * (1 - np.abs(x /
                                       (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) /
            (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        ref_im = ref_im.cuda()

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        print("Optimizing")

        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * gaussian_fit["std"] +
                                   gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesizer(latent_in) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict['L2']

            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield (best_im.cpu().detach().clamp(0, 1),
                       loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
        print(best_summary + current_info)
        if min_l2 <= eps:
            yield (gen_im.clone().cpu().detach().clamp(0, 1),
                   loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print(
                "Could not find a face that downscales correctly within epsilon"
            )
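
The lr_schedule options are identical across the examples and are plain multipliers handed to torch.optim.lr_scheduler.LambdaLR: 'fixed' keeps the learning rate constant, 'linear1cycle' ramps it from 0.1x up to 1x at the midpoint and back down to 0.1x, and 'linear1cycledrop' runs that ramp over the first 90% of the steps and then decays linearly toward 0.001x. A self-contained snippet that just evaluates the multipliers at a few checkpoints (steps=100 is an arbitrary choice for illustration):

import numpy as np

steps = 100  # arbitrary; stands in for the steps argument of the forward methods
schedule_dict = {
    'fixed': lambda x: 1,
    'linear1cycle': lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
    'linear1cycledrop': lambda x: (
        (9 * (1 - np.abs(x / (0.9 * steps) - 1 / 2) * 2) + 1) / 10
        if x < 0.9 * steps else
        1 / 10 + (x - 0.9 * steps) / (0.1 * steps) * (1 / 1000 - 1 / 10)),
}

for name, fn in schedule_dict.items():
    # multiplier at the start, the midpoint, the 90% mark, and the last step
    checkpoints = [0, steps // 2, int(0.9 * steps), steps - 1]
    print(name, [round(float(fn(x)), 3) for x in checkpoints])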