# Beispiel #1 (Example 1)
def eval(opt, netG):
    """Generation-only evaluation for the current scale (video version).

    Re-generates the dataset frames for ``opt.scale_idx``, draws
    ``opt.num_samples`` random videos per iteration from the generator,
    logs them to the summary writer, and saves all random samples plus the
    real full-scale frames under ``opt.saver.eval_dir``.

    Args:
        opt: Global options namespace; mutated in place (``fps``, ``td``,
            ``fps_index``, and lazily ``Z_init_size``).
        netG: Trained generator network.
    """
    # Re-generate dataset frames for the current scale.
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index
    opt.dataset.generate_frames(opt.scale_idx)

    torch.save(opt.dataset.frames,
               os.path.join(opt.saver.eval_dir, "real_full_scale.pth"))

    # Lazily compute the coarsest-scale noise size.
    # Presumably (batch, latent_dim, time_depth, H, W) — TODO confirm
    # against utils.generate_noise.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [
            opt.batch_size, opt.latent_dim, opt.td, *initial_size
        ]

    # Data-parallel wrapper on CUDA only.
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Generation scale [{}/{}]".format(opt.scale_idx + 1,
                                                  opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    # BUG FIX: the original referenced a bare ``data_loader`` name here
    # (NameError); the StopIteration restart path below already uses
    # ``opt.data_loader``.
    iterator = iter(opt.data_loader)

    random_samples = []

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Loader exhausted: restart it and keep iterating.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        # At finer scales the loader yields (real, real_zero) pairs; only
        # the current-scale tensor is used here.
        if opt.scale_idx > 0:
            real, _real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)

        # Fresh base noise; re-sampled per generated sample below (via ref=).
        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            for _ in range(opt.num_samples):
                noise_init = utils.generate_noise(ref=noise_init)
                fake, fake_vae = G_curr(noise_init,
                                        opt.Noise_Amps,
                                        noise_init=noise_init,
                                        mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)

        opt.summary.visualize_video(opt, iteration, real, 'Real')
        opt.summary.visualize_video(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_video(opt, iteration, fake_vae_var,
                                    'Fake VAE var')

        random_samples.append(fake_var)

    random_samples = torch.cat(random_samples, dim=0)
    torch.save(random_samples,
               os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()
# Beispiel #2 (Example 2)
    # Keep horizontal flips off by default for this run.
    parser.set_defaults(hflip=False)
    opt = parser.parse_args()

    # CLI values that must NOT be overridden by each experiment's logged
    # args.txt (they are choices made at evaluation time).
    exceptions = ['no-cuda', 'niter', 'data_rep', 'batch_size', 'netG']
    all_dirs = glob(opt.exp_dir)

    progressbar_args = {
        "iterable": all_dirs,
        "desc": "Experiments",
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    exp_iterator = tools.create_progressbar(**progressbar_args)

    for idx, exp_dir in enumerate(exp_iterator):
        opt.experiment_dir = exp_dir
        keys = vars(opt).keys()
        # Replay the experiment's logged arguments ("name: value" per line)
        # back onto ``opt`` so this run matches the training configuration.
        with open(os.path.join(exp_dir, 'args.txt'), 'r') as f:
            for line in f.readlines():
                # Strip spaces/newline, then split into (name, value).
                log_arg = line.replace(' ', '').replace('\n', '').split(':')
                assert len(log_arg) == 2
                if log_arg[0] in exceptions:
                    continue
                try:
                    # Prefer the literal Python value (ints, bools, lists, ...).
                    setattr(opt, log_arg[0], ast.literal_eval(log_arg[1]))
                except Exception:
                    # Fall back to the raw string when literal_eval fails.
                    setattr(opt, log_arg[0], log_arg[1])
# Beispiel #3 (Example 3)
def train(opt, netG):
    """Train the generator for one scale (and, past the VAE levels, a 3D
    discriminator as well).

    Mutates ``opt`` in place (``fps``, ``td``, ``fps_index``, ``noise_amp``,
    ``Noise_Amps``, lazily ``Z_init_size``) and saves checkpoints through
    ``opt.saver``.

    Args:
        opt: Global options namespace.
        netG: Multi-scale generator; ``netG.body`` holds one block per scale.
    """
    # Re-generate dataset frames for the current scale.
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index

    with logger.LoggingBlock("Updating dataset", emph=True):
        logging.info("{}FPS :{} {}{}".format(green, clear, opt.fps, clear))
        logging.info("{}Time-Depth :{} {}{}".format(green, clear, opt.td,
                                                    clear))
        logging.info("{}Sampling-Ratio :{} {}{}".format(
            green, clear, opt.sampling_rates[opt.fps_index], clear))
        opt.dataset.generate_frames(opt.scale_idx)

    # Lazily compute the coarsest-scale noise size.
    # Presumably (batch, latent_dim, time_depth, H, W) — TODO confirm.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [
            opt.batch_size, opt.latent_dim, opt.td, *initial_size
        ]

    # Past the pure-VAE levels, train adversarially with a 3D discriminator.
    if opt.vae_levels < opt.scale_idx + 1:
        D_curr = getattr(networks_3d, opt.discriminator)(opt).to(opt.device)

        if (opt.netG != '') and (opt.resumed_idx == opt.scale_idx):
            # Resuming this exact scale: restore the previous scale's D.
            D_curr.load_state_dict(
                torch.load('{}/netD_{}.pth'.format(
                    opt.resume_dir, opt.scale_idx - 1))['state_dict'])
        elif opt.vae_levels < opt.scale_idx:
            # Otherwise warm-start from the previous scale's checkpoint.
            D_curr.load_state_dict(
                torch.load(
                    '{}/netD_{}.pth'.format(opt.saver.experiment_dir,
                                            opt.scale_idx - 1))['state_dict'])

        # Current discriminator optimizer.
        optimizerD = optim.Adam(D_curr.parameters(),
                                lr=opt.lr_d,
                                betas=(opt.beta1, 0.999))

    # Per-block learning rates: blocks closer to the current (last) scale get
    # a larger LR via opt.lr_scale ** (distance from the end).
    parameter_list = []
    if not opt.train_all:
        if opt.vae_levels < opt.scale_idx + 1:
            # Adversarial phase: fine-tune only the last `train_depth` blocks.
            train_depth = min(opt.train_depth,
                              len(netG.body) - opt.vae_levels + 1)
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g *
                (opt.lr_scale**(len(netG.body[-train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-train_depth:])]
        else:
            # VAE phase: also train the encoder/decoder.
            parameter_list += [{
                "params": netG.encode.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }, {
                "params": netG.decoder.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }]
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g *
                (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-opt.train_depth:])]
    else:
        if len(netG.body) < opt.train_depth:
            # Few scales so far: train everything, including encoder/decoder.
            parameter_list += [{
                "params": netG.encode.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }, {
                "params": netG.decoder.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }]
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**(len(netG.body) - 1 - idx))
            } for idx, block in enumerate(netG.body)]
        else:
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g *
                (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-opt.train_depth:])]

    optimizerG = optim.Adam(parameter_list,
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))

    # Data-parallel wrappers on CUDA only.
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
        if opt.vae_levels < opt.scale_idx + 1:
            D_curr = torch.nn.DataParallel(D_curr)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1,
                                                opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    # BUG FIX: the original referenced a bare ``data_loader`` name here
    # (NameError); the StopIteration restart path below already uses
    # ``opt.data_loader``.
    iterator = iter(opt.data_loader)

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Loader exhausted: restart it and keep iterating.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        # At finer scales the loader yields (current scale, coarsest scale).
        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
            real_zero = real_zero.to(opt.device)
        else:
            real = data.to(opt.device)
            real_zero = real

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        ############################
        # calculate noise_amp (first iteration of each scale only)
        ###########################
        if iteration == 0:
            if opt.const_amp:
                opt.Noise_Amps.append(1)
            else:
                with torch.no_grad():
                    if opt.scale_idx == 0:
                        opt.noise_amp = 1
                        opt.Noise_Amps.append(opt.noise_amp)
                    else:
                        # Scale injected noise by the reconstruction RMSE of
                        # the new scale (placeholder 0 appended first so the
                        # rec pass sees an entry for this scale).
                        opt.Noise_Amps.append(0)
                        z_reconstruction, _, _ = G_curr(real_zero,
                                                        opt.Noise_Amps,
                                                        mode="rec")

                        RMSE = torch.sqrt(F.mse_loss(real, z_reconstruction))
                        opt.noise_amp = opt.noise_amp_init * RMSE.item(
                        ) / opt.batch_size
                        opt.Noise_Amps[-1] = opt.noise_amp

        ############################
        # (1) Update VAE network
        ###########################
        total_loss = 0

        generated, generated_vae, (mu, logvar) = G_curr(real_zero,
                                                        opt.Noise_Amps,
                                                        mode="rec")

        if opt.vae_levels >= opt.scale_idx + 1:
            # Pure VAE phase: reconstruction + KL divergence.
            rec_vae_loss = opt.rec_loss(generated, real) + opt.rec_loss(
                generated_vae, real_zero)
            kl_loss = kl_criterion(mu, logvar)
            vae_loss = opt.rec_weight * rec_vae_loss + opt.kl_weight * kl_loss

            total_loss += vae_loss
        else:
            ############################
            # (2) Update D network: maximize D(x) + D(G(z))
            ###########################
            # Train 3D discriminator with real data (WGAN-style sign).
            D_curr.zero_grad()
            output = D_curr(real)
            errD_real = -output.mean()

            # Train with fake samples (detached for the D step).
            fake, _ = G_curr(noise_init,
                             opt.Noise_Amps,
                             noise_init=noise_init,
                             mode="rand")

            output = D_curr(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(D_curr, real, fake,
                                                     opt.lambda_grad,
                                                     opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

            ############################
            # (3) Update G network: maximize D(G(z))
            ###########################
            errG_total = 0
            rec_loss = opt.rec_loss(generated, real)
            errG_total += opt.rec_weight * rec_loss

            # Adversarial term (fake was NOT detached here on purpose).
            output = D_curr(fake)
            errG = -output.mean() * opt.disc_loss_weight
            errG_total += errG

            total_loss += errG_total

        G_curr.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(G_curr.parameters(), opt.grad_clip)
        optimizerG.step()

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        if opt.visualize:
            # Tensorboard scalars. NOTE: the original logged noise_amp twice
            # per iteration to the same tag; the duplicate write was removed
            # and the phase-dependent scalars were consolidated into one
            # if/else so each phase logs its full set exactly once.
            opt.summary.add_scalar(
                'Video/Scale {}/noise_amp'.format(opt.scale_idx),
                opt.noise_amp, iteration)
            if opt.vae_levels >= opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/KLD'.format(opt.scale_idx), kl_loss.item(),
                    iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/Rec VAE'.format(opt.scale_idx),
                    rec_vae_loss.item(), iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/rec loss'.format(opt.scale_idx),
                    rec_loss.item(), iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errG'.format(opt.scale_idx), errG.item(),
                    iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_fake'.format(opt.scale_idx),
                    errD_fake.item(), iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_real'.format(opt.scale_idx),
                    errD_real.item(), iteration)

            if iteration % opt.print_interval == 0:
                # Sample a few random videos for visualization.
                with torch.no_grad():
                    fake_var = []
                    fake_vae_var = []
                    for _ in range(3):
                        noise_init = utils.generate_noise(ref=noise_init)
                        fake, fake_vae = G_curr(noise_init,
                                                opt.Noise_Amps,
                                                noise_init=noise_init,
                                                mode="rand")
                        fake_var.append(fake)
                        fake_vae_var.append(fake_vae)
                    fake_var = torch.cat(fake_var, dim=0)
                    fake_vae_var = torch.cat(fake_vae_var, dim=0)

                opt.summary.visualize_video(opt, iteration, real, 'Real')
                opt.summary.visualize_video(opt, iteration, generated,
                                            'Generated')
                opt.summary.visualize_video(opt, iteration, generated_vae,
                                            'Generated VAE')
                opt.summary.visualize_video(opt, iteration, fake_var,
                                            'Fake var')
                opt.summary.visualize_video(opt, iteration, fake_vae_var,
                                            'Fake VAE var')

    epoch_iterator.close()

    # Save checkpoints.
    opt.saver.save_checkpoint({'data': opt.Noise_Amps}, 'Noise_Amps.pth')
    opt.saver.save_checkpoint(
        {
            'scale': opt.scale_idx,
            'state_dict': netG.state_dict(),
            'optimizer': optimizerG.state_dict(),
            'noise_amps': opt.Noise_Amps,
        }, 'netG.pth')
    if opt.vae_levels < opt.scale_idx + 1:
        # Unwrap DataParallel so the checkpoint keys carry no 'module.' prefix.
        opt.saver.save_checkpoint(
            {
                'scale': opt.scale_idx,
                'state_dict': D_curr.module.state_dict()
                if opt.device == 'cuda' else D_curr.state_dict(),
                'optimizer': optimizerD.state_dict(),
            }, 'netD_{}.pth'.format(opt.scale_idx))
# Beispiel #4 (Example 4)
def eval(opt, netG):
    """Structured-noise probing evaluation (image version).

    For each sample, builds a rank-1 spatial noise map ``U @ V`` that puts a
    spike on one latent channel over a fixed spatial window, feeds it to the
    generator in place of random noise, logs the outputs, and saves all
    generated samples (grid PNG + tensor file).

    Args:
        opt: Global options namespace (lazily gains ``Z_init_size``).
        netG: Trained generator network.
    """
    # Lazily compute the coarsest-scale noise size.
    # Presumably (batch, latent_dim, H, W) — TODO confirm.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [opt.batch_size, opt.latent_dim, *initial_size]

    # Data-parallel wrapper on CUDA only.
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1,
                                                opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    # BUG FIX: the original referenced a bare ``data_loader`` name here
    # (NameError); the StopIteration restart path below already uses
    # ``opt.data_loader``.
    iterator = iter(opt.data_loader)

    random_samples = []

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Loader exhausted: restart it and keep iterating.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, _real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        G_curr.eval()
        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            for sample_idx in range(opt.num_samples):
                noise_init = utils.generate_noise(ref=noise_init)
                # Rank-1 structured noise: magnitude-4 spike on one latent
                # channel (one channel per sample) over a fixed window.
                # NOTE(review): the 128 channels and 22x33 grid are
                # hard-coded — they must match opt.latent_dim and this
                # scale's spatial size; and indexing U by sample_idx assumes
                # opt.num_samples <= 128 — confirm before reuse.
                U = torch.zeros(1, 128, 1).to(noise_init.device)
                U[:, sample_idx] = 4
                V = torch.zeros(1, 1, 22, 33).to(noise_init.device)
                V[:, :, 1:4, 20:32] = 1
                V = V.flatten(2)
                UV = torch.bmm(U, V).view(1, 128, 22, 33)
                # Standardize to zero mean / unit std, like random noise.
                UV = (UV - UV.mean()) / UV.std()
                # Replace the random noise with the structured probe.
                noise_init = UV
                fake, fake_vae = G_curr(noise_init,
                                        opt.Noise_Amps,
                                        noise_init=noise_init,
                                        mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)

        opt.summary.visualize_image(opt, iteration, real, 'Real')
        opt.summary.visualize_image(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_image(opt, iteration, fake_vae_var,
                                    'Fake VAE var')

        random_samples.append(fake_var)

    random_samples = torch.cat(random_samples, dim=0)
    from torchvision.utils import save_image
    save_image(random_samples, 'test.png', normalize=True)
    torch.save(random_samples,
               os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()
# Beispiel #5 (Example 5)
def eval(opt, netG):
    """Channel-spike probing evaluation (image version).

    For each sample, zeroes the noise map and places a constant spike in one
    latent channel over a fixed spatial window, feeds it to the generator,
    logs the outputs, and saves all generated samples (grid PNG + tensor
    file).

    Args:
        opt: Global options namespace (lazily gains ``Z_init_size``).
        netG: Trained generator network.
    """
    # Lazily compute the coarsest-scale noise size.
    # Presumably (batch, latent_dim, H, W) — TODO confirm.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [opt.batch_size, opt.latent_dim, *initial_size]

    # Data-parallel wrapper on CUDA only.
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1,
                                                opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    # BUG FIX: the original referenced a bare ``data_loader`` name here
    # (NameError); the StopIteration restart path below already uses
    # ``opt.data_loader``.
    iterator = iter(opt.data_loader)

    random_samples = []

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Loader exhausted: restart it and keep iterating.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, _real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            for sample_idx in range(opt.num_samples):
                noise_init = utils.generate_noise(ref=noise_init)
                # Probe one latent channel per sample: zero everything, then
                # write a constant spike into a fixed spatial window.
                # NOTE(review): indexing the channel axis by sample_idx
                # assumes opt.num_samples <= opt.latent_dim, and the window
                # coordinates assume this scale's spatial size — confirm.
                noise_init[:] = 0
                noise_init[:, sample_idx:sample_idx + 1, 8:11, 16:20] = 5
                fake, fake_vae = G_curr(noise_init,
                                        opt.Noise_Amps,
                                        noise_init=noise_init,
                                        mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)

        opt.summary.visualize_image(opt, iteration, real, 'Real')
        opt.summary.visualize_image(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_image(opt, iteration, fake_vae_var,
                                    'Fake VAE var')

        random_samples.append(fake_var)

    random_samples = torch.cat(random_samples, dim=0)
    from torchvision.utils import save_image
    save_image(random_samples, 'test.png')
    torch.save(random_samples,
               os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()