Example 1
def G_fun(batch) -> dict:
    x, y = batch
    G.train()
    D.train()
    # Inputs are assumed to be in [0, 1]; rescale to [-1, 1]
    out = G(x * 2 - 1)
    # Only G is being updated here, so skip DDP gradient sync for D
    with D.no_sync():
        loss = gan_loss.generated(D(torch.cat([out, x], dim=1) * 2 - 1))
    # L1 reconstruction term weighted by opts.l1_gain
    loss += opts.l1_gain * F.l1_loss(out, y)
    loss.backward()
    return {'G_loss': loss.item()}
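All five snippets assume a `gan_loss` helper exposing `real`, `fake`, and `generated`. A minimal sketch of that interface, assuming the standard non-saturating BCE formulation (the library actually in use may implement a different GAN objective):

import torch
import torch.nn.functional as F

class NonSaturatingGANLoss:
    # Sketch of the gan_loss interface assumed by the snippets, using
    # the classic non-saturating BCE-with-logits formulation.

    @staticmethod
    def real(pred):
        # D should score real samples as 1
        return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))

    @staticmethod
    def fake(pred):
        # D should score generated samples as 0
        return F.binary_cross_entropy_with_logits(pred, torch.zeros_like(pred))

    @staticmethod
    def generated(pred):
        # G is trained to make D score its samples as real
        return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))

gan_loss = NonSaturatingGANLoss()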
Example 2
def G_train(batch):
    # Path length regularization; run without DDP sync, since the main
    # G pass below will synchronize gradients anyway
    with G.no_sync():
        pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))

    ##############
    #   G pass   #
    ##############
    imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
    # diffTF: differentiable augmentation applied before D
    pred = D(diffTF(imgs) * 2 - 1)
    score = gan_loss.generated(pred)
    score.backward()

    return {'G_loss': score.item(), 'ppl': pl}
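Neither `ppl` nor `diffTF` is defined in this snippet. `diffTF` is presumably a DiffAugment-style differentiable augmentation applied before D; `ppl` reads like StyleGAN2's path length regularization. A rough sketch of such a regularizer, assuming it calls backward() internally and returns a float for logging:

import torch

class PathLengthPenalty:
    # Sketch of a StyleGAN2-style path length regularizer matching the
    # ppl(G, z) call above; an assumption, not the original helper.
    def __init__(self, decay=0.99):
        self.decay = decay
        self.mean = 0.0

    def __call__(self, G, z):
        z = z.requires_grad_(True)
        imgs = G(z)
        # Project the output onto random noise so grad wrt z is a cheap
        # Jacobian-vector product instead of a full Jacobian
        noise = torch.randn_like(imgs) / (imgs.shape[2] * imgs.shape[3]) ** 0.5
        grad, = torch.autograd.grad((imgs * noise).sum(), z, create_graph=True)
        lengths = grad.pow(2).sum(1).sqrt()
        # Penalize deviation from a running mean of path lengths
        self.mean = self.decay * self.mean + (1 - self.decay) * lengths.mean().item()
        penalty = (lengths - self.mean).pow(2).mean()
        penalty.backward()  # accumulates into G's gradients
        return penalty.item()

ppl = PathLengthPenalty()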
Example 3
def G_fun(batch) -> dict:
    x, y = batch
    # Keep D in eval mode during G's pass
    D.eval()
    out = G(x * 2 - 1)
    # Adversarial term; D is not updated here, so skip its DDP sync
    with D.no_sync():
        loss = gan_loss.generated(D(out * 2 - 1))
    # Keep the graph alive: the consistency term below backprops
    # through `out` a second time
    loss.backward(retain_graph=True)

    # Consistency term from the auxiliary model M
    with M.no_sync():
        clf_loss = opts.consistency * M(out * 2 - 1, x * 2 - 1)['loss']
    clf_loss.backward()
    return {'G_loss': loss.item(), 'consistency_loss': clf_loss.item()}
Example 4
def train_net(Gen, Discr):
    G = Gen(in_noise=32, out_ch=3).to(device)
    # D sees 4 channels: the 3-channel image concatenated with the
    # 1-channel conditioning map
    D = Discr(in_ch=4, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for epoch in range(10):
        for (x, x2), y in dl:
            # Replicate the grayscale target to 3 channels
            x = x.expand(-1, 3, -1, -1).to(device)
            x2 = x2.to(device)
            y = y.to(device)

            z = torch.randn(y.size(0), 32, device=device)
            fake = G(z, x2)

            # D step: detach fake so no gradients flow back into G
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(
                D(torch.cat([fake.detach(), x2], dim=1) * 2 - 1))
            loss_fake.backward()

            loss_true = gan_loss.real(D(torch.cat([x, x2], dim=1) * 2 - 1))
            loss_true.backward()
            opt_D.step()

            # G step: freeze D so only G accumulates gradients
            opt_G.zero_grad()
            freeze(D)
            loss = gan_loss.generated(D(torch.cat([fake, x2], dim=1) * 2 - 1))
            loss.backward()
            unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print(f"Iter {iters}, loss true {loss_true.item()}, "
                      f"loss fake {loss_fake.item()}, loss G {loss.item()}")

            if iters % 10 == 0:
                vis.images(x2, win='canny')
                vis.images(fake, opts=dict(store_history=True), win='fake')
            iters += 1
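Examples 4 and 5 rely on `freeze`/`unfreeze` helpers that are not shown. A minimal sketch, assuming they simply toggle `requires_grad` on D's parameters:

def freeze(net):
    # Stop the network's parameters from accumulating gradients
    # (used on D during G's update step)
    for p in net.parameters():
        p.requires_grad_(False)

def unfreeze(net):
    # Re-enable gradients before the next D update
    for p in net.parameters():
        p.requires_grad_(True)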
Example 5
def train_net(Gen, Discr):
    G = Gen(in_noise=32, out_ch=1).to(device)
    D = Discr(in_ch=1, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for epoch in range(4):
        for x, y in dl:
            x = x.to(device)
            y = y.to(device)

            # Match the noise batch to the data batch instead of
            # hardcoding 16, so a smaller final batch does not mismatch
            z = torch.randn(x.size(0), 32, device=device)
            fake = G(z)

            # D step: detach fake so no gradients flow back into G
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(D(fake.detach() * 2 - 1))
            loss_fake.backward()

            loss_true = gan_loss.real(D(x * 2 - 1))
            loss_true.backward()
            opt_D.step()

            # G step: freeze D so only G accumulates gradients
            opt_G.zero_grad()
            freeze(D)
            loss = gan_loss.generated(D(fake * 2 - 1))
            loss.backward()
            unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print(f"Iter {iters}, loss true {loss_true.item()}, "
                      f"loss fake {loss_fake.item()}, loss G {loss.item()}")

            if iters % 10 == 0:
                vis.images(fake, opts=dict(store_history=True), win='fake')
            iters += 1
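Both training loops leave `dl`, `device`, and `vis` to the surrounding script. A hypothetical setup matching Example 5's shapes (1-channel images, batch size 16, so an MNIST-style dataset):

import torch
import visdom
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vis = visdom.Visdom()
# Images come out of ToTensor() in [0, 1]; the loop rescales them to
# [-1, 1] before feeding D
dl = DataLoader(datasets.MNIST('.', download=True,
                               transform=transforms.ToTensor()),
                batch_size=16, shuffle=True)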