Example #1
    def sample(self, loaders):
        args = self.args
        nets_ema = self.nets_ema
        os.makedirs(args.result_dir, exist_ok=True)
        self._load_checkpoint(args.resume_iter)

        src = next(InputFetcher(loaders.src, None, args.latent_dim, 'test'))
        ref = next(InputFetcher(loaders.ref, None, args.latent_dim, 'test'))

        fname = ospj(args.result_dir, 'reference.jpg')
        print('Working on {}...'.format(fname))
        utils.translate_using_reference(nets_ema, args, src.x, ref.x, ref.y,
                                        fname)

        fname = ospj(args.result_dir, 'video_ref.mp4')
        print('Working on {}...'.format(fname))
        utils.video_ref(nets_ema, args, src.x, ref.x, ref.y, fname)

        N = src.x.size(0)
        device = src.x.device

        # one target label vector per domain (capped at 5 domains for layout)
        y_trg_list = [
            torch.tensor(y).repeat(N).to(device)
            for y in range(min(args.num_domains, 5))
        ]
        # num_outs_per_domain latent codes, shared across the batch dimension
        z_trg_list = torch.randn(args.num_outs_per_domain, 1,
                                 args.latent_dim).repeat(1, N, 1).to(device)
        for psi in [0.5, 0.7, 1.0]:
            fname = ospj(args.result_dir, 'latent_psi_%.1f.jpg' % psi)
            print('Working on {}...'.format(fname))
            utils.translate_using_latent(nets_ema, args, src.x, y_trg_list,
                                         z_trg_list, psi, fname)

        fname = ospj(args.result_dir, 'video_latent.mp4')
        print('Working on {}...'.format(fname))
        utils.video_latent(nets_ema, args, src.x, y_trg_list, z_trg_list,
                           psi, fname)  # reuses the last psi from the loop above
Example #2
    def sample(self, loaders):
        args = self.args
        nets_ema = self.nets_ema
        os.makedirs(args.result_dir, exist_ok=True)
        self._load_checkpoint(args.resume_iter)

        src = next(InputFetcher(loaders.src, None, args.latent_dim, 'test'))
        ref = next(InputFetcher(loaders.ref, None, args.latent_dim, 'test'))

        fname = ospj(args.result_dir, 'reference.jpg')
        print('Working on {}...'.format(fname))
        utils.translate_using_reference(nets_ema, args, src.x, ref.x, ref.y, fname)
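A minimal sketch of how a sample method like this is typically driven; the Munch container, get_test_loader, and Solver names are assumptions in the style of a StarGAN v2 entry point, not part of the example above:

    from munch import Munch

    def run_sampling(args):
        # hypothetical wiring: one loader of source images and one of
        # reference images, handed to the solver's sample() method
        loaders = Munch(src=get_test_loader(root=args.src_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size),
                        ref=get_test_loader(root=args.ref_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size))
        Solver(args).sample(loaders)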
Example #3
    def custom(self, loaders):
        args = self.args
        nets_ema = self.nets_ema
        self._load_checkpoint(100000)  # checkpoint iteration is hard-coded in this variant

        src = next(InputFetcher(loaders.src, None, args.latent_dim, 'test'))
        ref = next(InputFetcher(loaders.ref, None, args.latent_dim, 'test'))

        fname = args.custom_out_img
        print('Working on {}...'.format(fname))
        utils.translate_using_reference(nets_ema, args, src.x, ref.x, ref.y,
                                        fname)
Example #4
def calculate_metrics(nets, args, step, mode):
    print('Calculating evaluation metrics...')
    #assert mode in ['latent', 'reference']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    domains = list(range(args.num_domains))  # already in sorted order
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)

    for trg_idx, trg_domain in enumerate(domains):
        task = '%s' % trg_domain
        path_fake = os.path.join(args.eval_dir, task)
        shutil.rmtree(path_fake, ignore_errors=True)
        os.makedirs(path_fake)

        loader = get_sample_loader(root=args.val_img_dir,
                                   img_size=args.img_size,
                                   batch_size=args.val_batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   drop_last=False,
                                   trg_domain=trg_domain,
                                   mode=mode,
                                   dataset_dir=args.dataset_dir,
                                   threshold=args.num_sample)

        fetcher = InputFetcher(loader, None, args.latent_dim, 'test')

        print('Generating images for %s...' % task)

        for i in tqdm(range(len(loader))):
            # fetch images and labels
            inputs = next(fetcher)
            x_src, x_ref, y = inputs.src, inputs.trg, inputs.y
            N = x_src.size(0)
            x_src = x_src.to(device)
            x_ref = x_ref.to(device)
            y_trg = torch.tensor([trg_idx] * N).to(device)

            masks = None  # this evaluation path does not use FAN heatmap masks

            s_trg = nets.style_encoder(x_ref, y_trg)

            x_fake = nets.generator(x_src, s_trg, masks=masks)

            # save generated images to calculate FID later
            for k in range(N):
                filename = os.path.join(
                    path_fake,
                    '%04d.png' % (i * args.val_batch_size + (k + 1)))
                utils.save_image(x_fake[k], ncol=1, filename=filename)

    # calculate and report fid values
    fid_values, fid_mean = calculate_fid_for_all_tasks(
        args, domains, step=step, mode=mode, dataset_dir=args.dataset_dir)
    return fid_values, fid_mean
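The calculate_fid_for_all_tasks helper is not shown in this example. A minimal sketch consistent with the call above, assuming one real-image folder per domain under dataset_dir and a calculate_fid_given_paths helper in the style of the StarGAN v2 metrics code (both are assumptions, not confirmed by the snippet):

    from collections import OrderedDict
    import os

    def calculate_fid_for_all_tasks(args, domains, step, mode, dataset_dir):
        fid_values = OrderedDict()
        for trg_domain in domains:
            task = '%s' % trg_domain
            path_real = os.path.join(dataset_dir, task)    # assumed layout
            path_fake = os.path.join(args.eval_dir, task)  # written above
            fid_values['FID_%s/%s' % (mode, task)] = calculate_fid_given_paths(
                paths=[path_real, path_fake],
                img_size=args.img_size,
                batch_size=args.val_batch_size)
        fid_mean = sum(fid_values.values()) / len(fid_values)
        return fid_values, fid_mean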
Example #5
    def sample(self, loaders):
        args = self.args
        nets_ema = self.nets_ema
        os.makedirs(args.result_dir, exist_ok=True)
        self._load_checkpoint(args.resume_iter)
        fetch_src = InputFetcher(loaders.src, None, args.latent_dim, 'test')
        fetch_ref = InputFetcher(loaders.ref, None, args.latent_dim, 'test')
        for i in range(10000):
            src = next(fetch_src)
            ref = next(fetch_ref)

            # assumes batch size 1, so src.y / ref.y are single scalar labels
            fname = ospj(
                args.result_dir,
                str(i).zfill(5) + "from" + str(src.y.item()) + "to" +
                str(ref.y.item()) + '.jpg')
            print('Working on {}...'.format(fname))
            utils.translate_using_reference(nets_ema, args, src.x, src.y,
                                            ref.x, ref.y, fname)
Example #6
    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses_latent = compute_d_loss(
                nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            d_loss, d_losses_ref = compute_d_loss(
                nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            g_loss, g_losses_latent = compute_g_loss(
                nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            g_loss, g_losses_ref = compute_g_loss(
                nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i+1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
                                        ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)
                wandb.log({"Losses": all_losses})  # assumes wandb is imported and initialized elsewhere

            # generate images for debugging
            if (i+1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1)

            # compute FID and LPIPS if necessary
            if (i+1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i+1, mode='latent')
                calculate_metrics(nets_ema, args, i+1, mode='reference')
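The moving_average helper called in this loop is not shown here. A minimal sketch consistent with its call sites, matching the usual exponential-moving-average update (the project's actual implementation may differ):

    import torch

    def moving_average(model, model_ema, beta=0.999):
        # p_ema <- beta * p_ema + (1 - beta) * p, parameter by parameter
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            param_ema.data = torch.lerp(param.data, param_ema.data, beta)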
Example #7
    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim,
                               'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses_latent = compute_d_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trg=z_trg,
                                                     masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            d_loss, d_losses_ref = compute_d_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_ref=x_ref,
                                                  masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            g_loss, g_losses_latent = compute_g_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trgs=[z_trg, z_trg2],
                                                     masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            g_loss, g_losses_ref = compute_g_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_refs=[x_ref, x_ref2],
                                                  masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network,
                           nets_ema.mapping_network,
                           beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([
                        d_losses_latent, d_losses_ref, g_losses_latent,
                        g_losses_ref
                ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')

            # adaiw is this fork's custom AdaIN/whitening package; the import
            # would normally live at module top (repeating it per iteration is
            # harmless but unidiomatic)
            import adaiw
            if (i + 1) % args.print_learned == 0:
                alphas_list = []
                for module in self.nets.generator.decode:
                    if isinstance(module.norm1, adaiw.AdaIN):
                        alphas_list.append([
                            ('norm_1_white', module.norm1.alpha_white_param,
                             'norm_1_color', module.norm1.alpha_color_param),
                            ('norm_2_white', module.norm2.alpha_white_param,
                             'norm_2_color', module.norm2.alpha_color_param)
                        ])

                txt_f = os.path.join(args.notes_path, "alphas.txt")
                with open(txt_f, "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for item in alphas_list:
                        f.write("{}\n".format(item))

            if (i + 1) % args.print_std == 0:
                with open(os.path.join(args.notes_path, "std.txt"), "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for j, module in enumerate(self.nets.generator.decode):
                        print([
                            module.std_b4_norm_1.item(),
                            module.std_b4_norm_2.item(),
                            module.std_b4_join.item(),
                            module.std_b4_output.item()
                        ],
                              file=f)

            if (i + 1) % args.print_sqrt_error == 0:
                with open(os.path.join(args.notes_path, "sqrt_errors.txt"),
                          "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    all_errors = []
                    for j, module in enumerate(self.nets.generator.decode):
                        errors = []
                        for norm in [module.norm1, module.norm2]:
                            if isinstance(norm, adaiw.AdaIN) and isinstance(
                                    norm.normalizer,
                                    adaiw.normalizer.Whitening):
                                errors.append(norm.normalizer.last_error)
                        all_errors.append(tuple(errors))
                    print(all_errors, file=f)

            if (i + 1) % args.print_color == 0:
                with torch.no_grad():
                    with open(os.path.join(args.notes_path, "last_color.txt"),
                              "a+") as f:
                        f.write("{} iterations: \n".format(i + 1))
                        for _, module in enumerate(self.nets.generator.decode):
                            if hasattr(module.norm1, 'last_injected_stat'):
                                stat1 = module.norm1.last_injected_stat
                                diag1, tri1 = stat1.diagonal(), stat1.tril(-1)
                                print("Mean Diag:",
                                      diag1[0].mean().item(),
                                      "Std Diag:",
                                      diag1[0].std().item(),
                                      file=f)
                                print("Mean Tril:",
                                      tri1[0].mean().item(),
                                      "Std Tril:",
                                      tri1[0].std().item(),
                                      file=f)
                                stat2 = module.norm2.last_injected_stat
                                diag2, tri2 = stat2.diagonal(), stat2.tril(-1)
                                print("Mean Diag:",
                                      diag2[0].mean().item(),
                                      "Std Diag:",
                                      diag2[0].std().item(),
                                      file=f)
                                print("Mean Tril:",
                                      tri2[0].mean().item(),
                                      "Std Tril:",
                                      tri2[0].std().item(),
                                      file=f)
                                print(file=f)
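The compute_d_loss and compute_g_loss functions used by all of these training loops are defined elsewhere in the project. For orientation, a sketch of the discriminator side in the style of StarGAN v2; the adv_loss and r1_reg helpers are assumptions, and the fork's actual implementation may differ:

    import torch
    import torch.nn.functional as F
    from munch import Munch

    def adv_loss(logits, target):
        # non-saturating GAN loss against a constant target (1 real, 0 fake)
        targets = torch.full_like(logits, fill_value=target)
        return F.binary_cross_entropy_with_logits(logits, targets)

    def r1_reg(d_out, x_in):
        # zero-centered gradient penalty on real images
        grad = torch.autograd.grad(outputs=d_out.sum(), inputs=x_in,
                                   create_graph=True)[0]
        return 0.5 * grad.pow(2).reshape(grad.size(0), -1).sum(1).mean()

    def compute_d_loss(nets, args, x_real, y_org, y_trg,
                       z_trg=None, x_ref=None, masks=None):
        assert (z_trg is None) != (x_ref is None)
        # real branch
        x_real.requires_grad_()
        out = nets.discriminator(x_real, y_org)
        loss_real = adv_loss(out, 1)
        loss_reg = r1_reg(out, x_real)
        # fake branch: style comes from a latent code or a reference image
        with torch.no_grad():
            if z_trg is not None:
                s_trg = nets.mapping_network(z_trg, y_trg)
            else:
                s_trg = nets.style_encoder(x_ref, y_trg)
            x_fake = nets.generator(x_real, s_trg, masks=masks)
        loss_fake = adv_loss(nets.discriminator(x_fake, y_trg), 0)
        loss = loss_real + loss_fake + args.lambda_reg * loss_reg
        return loss, Munch(real=loss_real.item(), fake=loss_fake.item(),
                           reg=loss_reg.item())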
Example #8
    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        x_fixed = next(fetcher_val)
        x_fixed = x_fixed.x_src

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()

        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses_latent = compute_d_loss(
                nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            """
            Removing Reference based training
            d_loss, d_losses_ref = compute_d_loss(
                nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()
            """

            # train the generator
            g_loss, g_losses_latent = compute_g_loss(
                nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            """
            Removing reference based training
            g_loss, g_losses_ref = compute_g_loss(
                nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            """

            # compute moving average of network parameters
            # (disabled in this variant; note that nets_ema is therefore never
            # updated, so the periodic evaluation below sees the initial
            # EMA weights)
            """
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)
            """

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i+1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses_latent, g_losses_latent],
                                        ['D/latent_', 'G/latent_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)

            """
            # generate images for debugging
            if (i+1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)
            """
            if (i+1) % args.sample_every == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for j in range(args.num_domains):
                        # one column per target domain, each with a fresh latent code
                        label = torch.full((x_fixed.size(0),), j,
                                           dtype=torch.long).to(self.device)
                        z = torch.randn((x_fixed.size(0), args.latent_dim)).to(self.device)
                        style = self.nets.mapping_network(z, label)
                        x_fake_list.append(self.nets.generator(x_fixed, style))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    # torchvision's save_image and a denorm helper are assumed
                    # to be available in this fork
                    sample_path = os.path.join('samples', '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1)

            # compute FID and LPIPS if necessary
            if (i+1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i+1, mode='latent')
                calculate_metrics(nets_ema, args, i+1, mode='reference')
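self.denorm is not defined in this snippet. A plausible minimal implementation, assuming images are normalized to [-1, 1] as in the StarGAN family (an assumption, since the method is not shown):

    def denorm(self, x):
        # map images from [-1, 1] back to [0, 1] for saving
        out = (x + 1) / 2
        return out.clamp_(0, 1)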