def G_wrap(batch):
    tu.freeze(D)
    D.eval()
    tu.unfreeze(G)
    G.train()
    return G_fun(batch)

def test_wrap(batch):
    tu.freeze(G)
    tu.freeze(D)
    D.eval()
    G.eval()
    return test_fun(batch)

def D_wrap(batch):
    tu.freeze(G)
    G.eval()
    tu.unfreeze(D)
    D.train()
    return D_fun(batch)

def test_wrap(batch):
    tu.freeze(G)
    tu.freeze(D)
    D.eval()
    G.eval()
    with torch.no_grad():
        out = test_fun(batch)
    D.train()
    G.train()
    return out
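# The wrappers above alternate which network is trainable before calling the
# user-provided step functions. `tu` is torchelie.utils; a minimal sketch of
# what freeze/unfreeze are assumed to do (toggle requires_grad on every
# parameter), in case you want the pattern without the library:
import torch.nn as nn

def freeze(net: nn.Module) -> nn.Module:
    # Stop gradients from accumulating in this network's parameters.
    for p in net.parameters():
        p.requires_grad_(False)
    return net

def unfreeze(net: nn.Module) -> nn.Module:
    # Make the network trainable again.
    for p in net.parameters():
        p.requires_grad_(True)
    return net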
def __init__(self) -> None:
    super(NeuralStyleLoss, self).__init__()
    self.style_layers = [
        'conv1_1',
        'conv2_1',
        'conv3_1',
        'conv4_1',
        'conv5_1',
    ]
    self.style_weights = {'conv5_1': 1.0, 'conv4_1': 1.0}
    self.style_hists = {}
    self.content_layers = ['conv3_2']
    self.hists_layers = ['conv5_1']

    self.net = PerceptualNet(self.style_layers + self.content_layers,
                             remove_unused_layers=False)
    tu.freeze(self.net)
    self.norm = ImageNetInputNorm()
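# NeuralStyleLoss compares Gram matrices of VGG activations at the
# style_layers listed above (Gatys et al.). A self-contained sketch of that
# core computation, not the class's actual internals:
import torch

def gram_matrix(feats: torch.Tensor) -> torch.Tensor:
    # feats: (N, C, H, W) activations from one layer, e.g. 'conv3_1'
    n, c, h, w = feats.shape
    flat = feats.view(n, c, h * w)
    # Pairwise channel correlations, normalized by feature map size.
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)

# The style term is then a weighted sum of MSEs between the Gram matrices of
# the input image and of the style image across the chosen layers.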
def train_net(Gen, Discr):
    G = Gen(in_noise=32, out_ch=3).to(device)
    D = Discr(in_ch=4, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for epoch in range(10):
        for (x, x2), y in dl:
            x = x.expand(-1, 3, -1, -1).to(device)
            x2 = x2.to(device)
            y = y.to(device)
            z = torch.randn(y.size(0), 32).to(device)
            fake = G(z, x2)

            # D update: the condition x2 is concatenated along channels,
            # and images are rescaled from [0, 1] to [-1, 1].
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(
                D(torch.cat([fake.detach(), x2], dim=1) * 2 - 1))
            loss_fake.backward()
            loss_true = gan_loss.real(D(torch.cat([x, x2], dim=1) * 2 - 1))
            loss_true.backward()
            opt_D.step()

            # G update: freeze D so only G accumulates gradients.
            opt_G.zero_grad()
            freeze(D)
            loss = gan_loss.generated(D(torch.cat([fake, x2], dim=1) * 2 - 1))
            loss.backward()
            unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print("Iter {}, loss true {}, loss fake {}, loss G {}".format(
                    iters, loss_true.item(), loss_fake.item(), loss.item()))
            if iters % 10 == 0:
                vis.images(x2, win='canny')
                vis.images(fake, opts=dict(store_history=True), win='fake')
            iters += 1
def train_net(Gen, Discr):
    G = Gen(in_noise=32, out_ch=1).to(device)
    D = Discr(in_ch=1, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for epoch in range(4):
        for x, y in dl:
            x = x.to(device)
            y = y.to(device)
            # Match the actual batch size so the last, possibly smaller,
            # batch stays consistent between the real and fake passes.
            z = torch.randn(x.size(0), 32).to(device)
            fake = G(z)

            # D update on detached fakes and real samples in [-1, 1].
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(D(fake.detach() * 2 - 1))
            loss_fake.backward()
            loss_true = gan_loss.real(D(x * 2 - 1))
            loss_true.backward()
            opt_D.step()

            # G update with D frozen.
            opt_G.zero_grad()
            freeze(D)
            loss = gan_loss.generated(D(fake * 2 - 1))
            loss.backward()
            unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print("Iter {}, loss true {}, loss fake {}, loss G {}".format(
                    iters, loss_true.item(), loss_fake.item(), loss.item()))
            if iters % 10 == 0:
                vis.images(fake, opts=dict(store_history=True), win='fake')
            iters += 1
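# The gan_loss.real / .fake / .generated triple used in both loops follows
# the standard non-saturating GAN objective. A drop-in sketch of one common
# definition (assumption: the library's version is equivalent up to
# reduction), should you want to run these loops standalone:
import torch
import torch.nn.functional as F

def real(pred: torch.Tensor) -> torch.Tensor:
    # D should classify real samples as positive (label 1).
    return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))

def fake(pred: torch.Tensor) -> torch.Tensor:
    # D should classify generated samples as negative (label 0).
    return F.binary_cross_entropy_with_logits(pred, torch.zeros_like(pred))

def generated(pred: torch.Tensor) -> torch.Tensor:
    # Non-saturating G loss: push D to label fakes as real.
    return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))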
def StyleGAN2Recipe(G: nn.Module,
                    D: nn.Module,
                    dataloader,
                    noise_size: int,
                    gpu_id: int,
                    total_num_gpus: int,
                    *,
                    G_lr: float = 2e-3,
                    D_lr: float = 2e-3,
                    tag: str = 'model',
                    ada: bool = True):
    """
    StyleGAN2 recipe distributed with DistributedDataParallel

    Args:
        G (nn.Module): a Generator.
        D (nn.Module): a Discriminator.
        dataloader: a dataloader conforming to torchvision's API.
        noise_size (int): the size of the input noise vector.
        gpu_id (int): the GPU index on which to run.
        total_num_gpus (int): the total number of GPUs used for training.
        G_lr (float): RAdamW learning rate for G.
        D_lr (float): RAdamW learning rate for D.
        tag (str): tag for Visdom and checkpoints.
        ada (bool): whether to enable Adaptive Data Augmentation.

    Returns:
        recipe, G EMA model
    """
    # Keep an exponential moving average copy of G; a deep copy is needed
    # so the averaged weights don't alias G's parameters.
    G_polyak = copy.deepcopy(G)

    G = nn.parallel.DistributedDataParallel(G.to(gpu_id), [gpu_id], gpu_id)
    D = nn.parallel.DistributedDataParallel(D.to(gpu_id), [gpu_id], gpu_id)
    print(G)
    print(D)

    optG: torch.optim.Optimizer = RAdamW(G.parameters(),
                                         G_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)
    optD: torch.optim.Optimizer = RAdamW(D.parameters(),
                                         D_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)

    batch_size = len(next(iter(dataloader))[0])
    diffTF = ADATF(-2 if not ada else -0.9,
                   50000 / (batch_size * total_num_gpus))
    ppl = PPL(4)

    def G_train(batch):
        with G.no_sync():
            # Path length regularization, without gradient syncing.
            pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))
        ##############
        #   G pass   #
        ##############
        imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
        pred = D(diffTF(imgs) * 2 - 1)
        score = gan_loss.generated(pred)
        score.backward()
        return {'G_loss': score.item(), 'ppl': pl}

    gradient_penalty = GradientPenalty(0.1)

    def D_train(batch):
        ###################
        #    Fake pass    #
        ###################
        with D.no_sync():
            # Sync the gradient on the last backward only.
            noise = torch.randn(batch_size, noise_size, device=gpu_id)
            with torch.no_grad():
                fake = G(noise)
            fake.requires_grad_(True)
            fake.retain_grad()

            fake_tf = diffTF(fake) * 2 - 1
            fakeness = D(fake_tf).squeeze(1)
            fake_loss = gan_loss.fake(fakeness)
            fake_loss.backward()
            correct = (fakeness < 0).int().eq(1).float().sum()

            # Per-pixel gradient magnitude on the fakes, for visualization.
            fake_grad = fake.grad.detach().norm(dim=1, keepdim=True)
            fake_grad /= fake_grad.max()

        tfmed = diffTF(batch[0]) * 2 - 1

        with D.no_sync():
            grad_norm = gradient_penalty(D, batch[0] * 2 - 1,
                                         fake.detach() * 2 - 1)

        ###################
        #    Real pass    #
        ###################
        real_out = D(tfmed)
        correct += (real_out > 0).detach().int().eq(1).float().sum()
        real_loss = gan_loss.real(real_out)
        real_loss.backward()

        pos_ratio = real_out.gt(0).float().mean().cpu().item()
        diffTF.log_loss(-pos_ratio)
        return {
            'imgs': fake.detach(),
            'i_grad': fake_grad,
            'loss': real_loss.item() + fake_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
            'ADA-p': diffTF.p,
            'D-correct': correct / (2 * real_out.numel()),
            'grad_norm': grad_norm
        }

    tu.freeze(G_polyak)

    def test(batch):
        G_polyak.eval()

        def sample(N, n_iter, alpha=0.01, show_every=10):
            noise = torch.randn(N,
                                noise_size,
                                device=gpu_id,
                                requires_grad=True)
            opt = torch.optim.Adam([noise], lr=alpha)
            fakes = []
            for i in range(n_iter):
                # Perturb the latent in-place; noise is a leaf that
                # requires grad, so the update must not be recorded.
                with torch.no_grad():
                    noise += torch.randn_like(noise) / 10
                fake_batch = []
                opt.zero_grad()
                for j in range(0, N, batch_size):
                    # The recipe runs test() under no_grad, so autograd
                    # has to be re-enabled locally.
                    with torch.enable_grad():
                        n_batch = noise[j:j + batch_size]
                        fake = G_polyak(n_batch, mixing=False)
                        fake_batch.append(fake)

                        log_prob = n_batch[:, 32:].pow(2).mul_(-0.5)
                        fakeness = -D(fake * 2 - 1).sum() - log_prob.sum()
                        fakeness.backward()
                opt.step()
                fake_batch = torch.cat(fake_batch, dim=0)

                if i % show_every == 0:
                    fakes.append(fake_batch.cpu().detach().clone())
            fakes.append(fake_batch.cpu().detach().clone())
            return torch.cat(fakes, dim=0)

        fake = sample(8, 50, alpha=0.001, show_every=10)

        # Latent interpolation between two noise batches in 8 steps.
        noise1 = torch.randn(batch_size * 2 // 8, 1, noise_size,
                             device=gpu_id)
        noise2 = torch.randn(batch_size * 2 // 8, 1, noise_size,
                             device=gpu_id)
        t = torch.linspace(0, 1, 8, device=noise1.device).view(8, 1)
        noise = noise1 * t + noise2 * (1 - t)
        noise = noise.view(-1, noise_size)
        interp = torch.cat([
            G_polyak(n, mixing=False) for n in torch.split(noise, batch_size)
        ], dim=0)

        return {
            'polyak_imgs': fake,
            'polyak_interp': interp,
        }

    recipe = GANRecipe(G,
                       D,
                       G_train,
                       D_train,
                       test,
                       dataloader,
                       visdom_env=tag if gpu_id == 0 else None,
                       log_every=10,
                       test_every=1000,
                       checkpoint=tag if gpu_id == 0 else None,
                       g_every=1)
    recipe.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('grad_norm'),
        tcb.WindowedMetricAvg('ADA-p'),
        tcb.WindowedMetricAvg('D-correct'),
        tcb.Log('i_grad', 'img_grad'),
        tch.callbacks.Optimizer(optD),
    ])
    recipe.G_loop.callbacks.add_callbacks([
        tch.callbacks.Optimizer(optG),
        tcb.Polyak(G.module, G_polyak,
                   0.5**((batch_size * total_num_gpus) / 20000)),
        tcb.WindowedMetricAvg('ppl'),
    ])
    recipe.test_loop.callbacks.add_callbacks([
        tcb.Log('polyak_imgs', 'polyak'),
        tcb.Log('polyak_interp', 'interp'),
    ])
    recipe.register('G_polyak', G_polyak)
    recipe.to(gpu_id)
    return recipe, G_polyak
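# Hedged usage sketch: launching the recipe with one process per GPU.
# `make_models` and `make_dataloader` are hypothetical placeholders, the
# init_method/port and noise_size are examples, and recipe.run is assumed to
# take a number of epochs; the recipe itself only requires that a process
# group is already initialized.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    dist.init_process_group('nccl',
                            init_method='tcp://127.0.0.1:29500',
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)
    G, D = make_models()            # hypothetical model factory
    dl = make_dataloader(rank)      # hypothetical per-rank dataloader
    recipe, G_ema = StyleGAN2Recipe(G, D, dl,
                                    noise_size=128,
                                    gpu_id=rank,
                                    total_num_gpus=world_size)
    recipe.run(100)                 # assumed Recipe API: run for N epochs

if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus)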
def G_wrap(batch):
    tu.freeze(D)
    tu.unfreeze(G)
    return G_fun(batch)

def D_wrap(batch):
    tu.freeze(G)
    tu.unfreeze(D)
    return D_fun(batch)