Example 1
    def G_wrap(batch):
        # Generator step: freeze D so it gets no gradient updates and keep
        # it in eval mode; G is unfrozen and set to train mode.
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()

        return G_fun(batch)
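
Here `tu` is presumably torchelie.utils. A minimal sketch of what freeze/unfreeze are assumed to do, namely toggling requires_grad on every parameter so a frozen net neither accumulates gradients nor gets updated by its optimizer:

    def freeze(net):
        # stop autograd from tracking these parameters
        for p in net.parameters():
            p.requires_grad_(False)
        return net

    def unfreeze(net):
        # make the parameters trainable again
        for p in net.parameters():
            p.requires_grad_(True)
        return net

Note that eval()/train() are still needed separately, because freezing does not change batch-norm or dropout behavior; the wrappers above handle both.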
Example 2
    def test_wrap(batch):
        # Evaluation step: freeze both nets and switch them to eval mode.
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        return test_fun(batch)
Example 3
    def D_wrap(batch):
        # Discriminator step: G is frozen and fixed in eval mode; D is
        # unfrozen and trained.
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()

        return D_fun(batch)
Example 4
    def test_wrap(batch):
        # Evaluation step: freeze both nets and run the test function with
        # autograd disabled; train mode is restored afterwards.
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        with torch.no_grad():
            out = test_fun(batch)

        D.train()
        G.train()
        return out
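
Compared with Example 2, this variant also runs the test function under torch.no_grad(), so no autograd graph is built during evaluation, and it restores train() afterwards so the next training step sees the correct batch-norm and dropout behavior.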
Example 5
    def __init__(self) -> None:
        super(NeuralStyleLoss, self).__init__()
        # Style is matched on the first conv of each VGG block; content on
        # conv3_2. The perceptual backbone is frozen.
        self.style_layers = [
            'conv1_1',
            'conv2_1',
            'conv3_1',
            'conv4_1',
            'conv5_1',
        ]
        self.style_weights = {'conv5_1': 1.0, 'conv4_1': 1.0}
        self.style_hists = {}
        self.content_layers = ['conv3_2']
        self.hists_layers = ['conv5_1']
        self.net = PerceptualNet(self.style_layers + self.content_layers,
                                 remove_unused_layers=False)
        tu.freeze(self.net)
        self.norm = ImageNetInputNorm()
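
The layer choice (style on the first convolution of each VGG block, content on conv3_2) follows Gatys-style neural style transfer, where the style statistic is usually a per-layer Gram matrix. A minimal sketch of that statistic, assuming (N, C, H, W) activations; the actual forward pass is not part of this excerpt:

    def gram_matrix(feats):
        # channel-to-channel correlations, normalized by layer size
        n, c, h, w = feats.shape
        f = feats.view(n, c, h * w)
        return f @ f.transpose(1, 2) / (c * h * w)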
Example 6
def train_net(Gen, Discr):
    # Assumes module-level names: `device`, a DataLoader `dl` yielding
    # ((x, x2), y) pairs, `gan_loss` with real()/fake()/generated(),
    # a visdom instance `vis`, `Adam` from torch.optim, and
    # freeze/unfreeze from torchelie.utils.
    G = Gen(in_noise=32, out_ch=3).to(device)
    D = Discr(in_ch=4, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for epoch in range(10):
        for (x, x2), y in dl:
            x = x.expand(-1, 3, -1, -1).to(device)
            x2 = x2.to(device)
            y = y.to(device)

            z = torch.randn(y.size(0), 32).to(device)
            fake = G(z, x2)

            # D step: fakes are detached so only D receives gradients;
            # inputs are rescaled from [0, 1] to [-1, 1].
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(
                D(torch.cat([fake.detach(), x2], dim=1) * 2 - 1))
            loss_fake.backward()

            loss_true = gan_loss.real(D(torch.cat([x, x2], dim=1) * 2 - 1))
            loss_true.backward()
            opt_D.step()

            # G step: D is frozen so the generator loss updates only G.
            opt_G.zero_grad()
            freeze(D)
            loss = gan_loss.generated(D(torch.cat([fake, x2], dim=1) * 2 - 1))
            loss.backward()
            unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print("Iter {}, loss true {}, loss fake {}, loss G {}".format(
                    iters, loss_true.item(), loss_fake.item(), loss.item()))

            if iters % 10 == 0:
                vis.images(x2, win='canny')
                vis.images(fake, opts=dict(store_history=True), win='fake')
            iters += 1
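
The `gan_loss` namespace provides the three classic loss terms. A hypothetical stand-in, assuming the standard non-saturating BCE-with-logits formulation (the library's actual implementation may differ):

    import torch
    import torch.nn.functional as F

    def real(pred):
        # D should assign high logits to real samples
        return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))

    def fake(pred):
        # D should assign low logits to generated samples
        return F.binary_cross_entropy_with_logits(pred, torch.zeros_like(pred))

    def generated(pred):
        # non-saturating generator loss: push D's logits on fakes up
        return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))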
Example 7
def train_net(Gen, Discr):
    # Same assumptions as Example 6, with an unconditional G and a
    # single-channel D; `dl` yields (x, y) pairs.
    G = Gen(in_noise=32, out_ch=1).to(device)
    D = Discr(in_ch=1, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for epoch in range(4):
        for x, y in dl:
            x = x.to(device)
            y = y.to(device)

            z = torch.randn(16, 32).to(device)  # fixed sample count per step
            fake = G(z)

            opt_D.zero_grad()
            loss_fake = gan_loss.fake(D(fake.detach() * 2 - 1))
            loss_fake.backward()

            loss_true = gan_loss.real(D(x * 2 - 1))
            loss_true.backward()
            opt_D.step()

            opt_G.zero_grad()
            freeze(D)
            loss = gan_loss.generated(D(fake * 2 - 1))
            loss.backward()
            unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print("Iter {}, loss true {}, loss fake {}, loss G {}".format(
                    iters, loss_true.item(), loss_fake.item(), loss.item()))

            if iters % 10 == 0:
                vis.images(fake, opts=dict(store_history=True), win='fake')
            iters += 1
Example 8
def StyleGAN2Recipe(G: nn.Module,
                    D: nn.Module,
                    dataloader,
                    noise_size: int,
                    gpu_id: int,
                    total_num_gpus: int,
                    *,
                    G_lr: float = 2e-3,
                    D_lr: float = 2e-3,
                    tag: str = 'model',
                    ada: bool = True):
    """
    StyleGAN2 training recipe, distributed with DistributedDataParallel.

    Args:
        G (nn.Module): a Generator.
        D (nn.Module): a Discriminator.
        dataloader: a dataloader conforming to torchvision's API.
        noise_size (int): the size of the input noise vector.
        gpu_id (int): the GPU index on which to run.
        total_num_gpus (int): the total number of GPUs used for training.
        G_lr (float): learning rate for G's RAdamW optimizer.
        D_lr (float): learning rate for D's RAdamW optimizer.
        tag (str): tag used for the Visdom environment and checkpoints.
        ada (bool): whether to enable Adaptive Data Augmentation

    Returns:
        the recipe and the Polyak-averaged (EMA) copy of G
    """
    # A deep copy is required: copy.copy would share parameters with G, so
    # freezing and Polyak-averaging G_polyak would interfere with training.
    G_polyak = copy.deepcopy(G)

    G = nn.parallel.DistributedDataParallel(G.to(gpu_id), [gpu_id], gpu_id)
    D = nn.parallel.DistributedDataParallel(D.to(gpu_id), [gpu_id], gpu_id)
    print(G)
    print(D)

    optG: torch.optim.Optimizer = RAdamW(G.parameters(),
                                         G_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)
    optD: torch.optim.Optimizer = RAdamW(D.parameters(),
                                         D_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)

    batch_size = len(next(iter(dataloader))[0])  # per-GPU batch size
    diffTF = ADATF(-2 if not ada else -0.9,
                   50000 / (batch_size * total_num_gpus))

    ppl = PPL(4)

    def G_train(batch):
        with G.no_sync():
            pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))
        ##############
        #   G pass   #
        ##############
        imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
        pred = D(diffTF(imgs) * 2 - 1)
        score = gan_loss.generated(pred)
        score.backward()

        return {'G_loss': score.item(), 'ppl': pl}

    gradient_penalty = GradientPenalty(0.1)

    def D_train(batch):
        ###################
        #    Fake pass    #
        ###################
        with D.no_sync():
            # Sync the gradient on the last backward
            noise = torch.randn(batch_size, noise_size, device=gpu_id)
            with torch.no_grad():
                fake = G(noise)
            fake.requires_grad_(True)
            fake.retain_grad()
            fake_tf = diffTF(fake) * 2 - 1
            fakeness = D(fake_tf).squeeze(1)
            fake_loss = gan_loss.fake(fakeness)
            fake_loss.backward()

            correct = (fakeness < 0).float().sum()
        fake_grad = fake.grad.detach().norm(dim=1, keepdim=True)
        fake_grad /= fake_grad.max()

        tfmed = diffTF(batch[0]) * 2 - 1

        with D.no_sync():
            grad_norm = gradient_penalty(D, batch[0] * 2 - 1,
                                         fake.detach() * 2 - 1)

        ###################
        #    Real pass    #
        ###################
        real_out = D(tfmed)
        correct += (real_out > 0).float().sum()
        real_loss = gan_loss.real(real_out)
        real_loss.backward()
        pos_ratio = real_out.gt(0).float().mean().cpu().item()
        diffTF.log_loss(-pos_ratio)
        return {
            'imgs': fake.detach(),
            'i_grad': fake_grad,
            'loss': real_loss.item() + fake_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
            'ADA-p': diffTF.p,
            'D-correct': correct / (2 * real_out.numel()),
            'grad_norm': grad_norm
        }

    tu.freeze(G_polyak)

    def test(batch):
        G_polyak.eval()

        def sample(N, n_iter, alpha=0.01, show_every=10):
            noise = torch.randn(N,
                                noise_size,
                                device=gpu_id,
                                requires_grad=True)
            opt = torch.optim.Adam([noise], lr=alpha)
            fakes = []
            for i in range(n_iter):
                # In-place update of a leaf tensor that requires grad must
                # run outside autograd.
                with torch.no_grad():
                    noise += torch.randn_like(noise) / 10
                fake_batch = []
                opt.zero_grad()
                for j in range(0, N, batch_size):
                    with torch.enable_grad():
                        n_batch = noise[j:j + batch_size]
                        fake = G_polyak(n_batch, mixing=False)
                        fake_batch.append(fake)
                        log_prob = n_batch[:, 32:].pow(2).mul_(-0.5)
                        fakeness = -D(fake * 2 - 1).sum() - log_prob.sum()
                        fakeness.backward()
                opt.step()
                fake_batch = torch.cat(fake_batch, dim=0)

                if i % show_every == 0:
                    fakes.append(fake_batch.cpu().detach().clone())

            fakes.append(fake_batch.cpu().detach().clone())

            return torch.cat(fakes, dim=0)

        fake = sample(8, 50, alpha=0.001, show_every=10)

        noise1 = torch.randn(batch_size * 2 // 8, 1, noise_size, device=gpu_id)
        noise2 = torch.randn(batch_size * 2 // 8, 1, noise_size, device=gpu_id)
        t = torch.linspace(0, 1, 8, device=noise1.device).view(8, 1)
        noise = noise1 * t + noise2 * (1 - t)
        noise = noise.view(-1, noise_size)
        interp = torch.cat(
            [G_polyak(n, mixing=False) for n in torch.split(noise, batch_size)],
            dim=0)
        return {
            'polyak_imgs': fake,
            'polyak_interp': interp,
        }

    recipe = GANRecipe(G,
                       D,
                       G_train,
                       D_train,
                       test,
                       dataloader,
                       visdom_env=tag if gpu_id == 0 else None,
                       log_every=10,
                       test_every=1000,
                       checkpoint=tag if gpu_id == 0 else None,
                       g_every=1)
    recipe.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('grad_norm'),
        tcb.WindowedMetricAvg('ADA-p'),
        tcb.WindowedMetricAvg('D-correct'),
        tcb.Log('i_grad', 'img_grad'),
        tch.callbacks.Optimizer(optD),
    ])
    recipe.G_loop.callbacks.add_callbacks([
        tch.callbacks.Optimizer(optG),
        tcb.Polyak(G.module, G_polyak,
                   0.5**((batch_size * total_num_gpus) / 20000)),
        tcb.WindowedMetricAvg('ppl'),
    ])
    recipe.test_loop.callbacks.add_callbacks([
        tcb.Log('polyak_imgs', 'polyak'),
        tcb.Log('polyak_interp', 'interp'),
    ])
    recipe.register('G_polyak', G_polyak)
    recipe.to(gpu_id)
    return recipe, G_polyak
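
A hypothetical launch sketch for this recipe, assuming one process per GPU (e.g. spawned by torchrun) and that the recipe exposes an epoch-count `run` entry point; the model constructors are placeholders, not part of the excerpt:

    import torch
    import torch.distributed as dist

    def main(dataloader):
        dist.init_process_group('nccl')
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        G = make_generator()        # hypothetical constructor
        D = make_discriminator()    # hypothetical constructor
        recipe, G_ema = StyleGAN2Recipe(G, D, dataloader,
                                        noise_size=128,
                                        gpu_id=rank,
                                        total_num_gpus=dist.get_world_size(),
                                        tag='stylegan2')
        recipe.run(100)             # assumed API: train for 100 epochs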
Example 9
    def G_wrap(batch):
        # Generator step: only G receives gradients.
        tu.freeze(D)
        tu.unfreeze(G)

        return G_fun(batch)
Example 10
    def D_wrap(batch):
        # Discriminator step: only D receives gradients.
        tu.freeze(G)
        tu.unfreeze(D)

        return D_fun(batch)
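
Taken together, these D_wrap/G_wrap pairs implement the usual alternating GAN schedule. A minimal, hypothetical driver showing how the wrappers compose; the real GANRecipe layers callbacks, logging, and checkpointing on top:

    for batch in dataloader:
        d_out = D_wrap(batch)  # D trains against a frozen G
        g_out = G_wrap(batch)  # G trains against a frozen D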