Beispiel #1
0
    def G_wrap(batch):
        """Prepare both networks for a generator step, then delegate to ``G_fun``.

        The discriminator ``D`` is frozen via ``tu.freeze`` and switched to
        eval mode. The generator ``G`` is unfrozen via ``tu.unfreeze`` and
        switched to train mode. ``tu``, ``D``, ``G`` and ``G_fun`` are free
        variables taken from the enclosing scope.

        Args:
            batch: forwarded unchanged to ``G_fun``.

        Returns:
            Whatever ``G_fun(batch)`` returns.
        """
        # Applied in order: discriminator first (frozen, eval), then the
        # generator (trainable, train).
        for net, trainable in ((D, False), (G, True)):
            if trainable:
                tu.unfreeze(net)
                net.train()
            else:
                tu.freeze(net)
                net.eval()
        return G_fun(batch)
Beispiel #2
0
    def D_wrap(batch):
        """Prepare both networks for a discriminator step, then delegate to ``D_fun``.

        The generator ``G`` is frozen via ``tu.freeze`` and switched to eval
        mode. The discriminator ``D`` is unfrozen via ``tu.unfreeze`` and
        switched to train mode. ``tu``, ``D``, ``G`` and ``D_fun`` are free
        variables taken from the enclosing scope.

        Args:
            batch: forwarded unchanged to ``D_fun``.

        Returns:
            Whatever ``D_fun(batch)`` returns.
        """
        # Applied in order: generator first (frozen, eval), then the
        # discriminator (trainable, train).
        for net, trainable in ((G, False), (D, True)):
            if trainable:
                tu.unfreeze(net)
                net.train()
            else:
                tu.freeze(net)
                net.eval()
        return D_fun(batch)
Beispiel #3
0
def train_net(Gen, Discr, *, epochs=10, noise_size=32):
    """Train a GAN whose generator is conditioned on ``x2``.

    Each iteration does one discriminator update (a fake batch, then a real
    batch), followed by one generator update through a frozen discriminator.
    Losses are printed every 100 iterations. ``x2`` and the current fakes are
    sent to visdom every 10 iterations.

    Relies on module-level names defined elsewhere: ``dl``, ``device``,
    ``gan_loss``, ``vis``, ``freeze``, ``unfreeze``, ``Adam`` and ``torch``.

    Args:
        Gen: generator class, built as ``Gen(in_noise=noise_size, out_ch=3)``
            and called as ``G(z, x2)``.
        Discr: discriminator class, built as
            ``Discr(in_ch=4, out_ch=1, norm=None)``. Its input is a 3-channel
            image concatenated with ``x2``, so ``x2`` is expected to be
            single-channel.
        epochs: number of passes over ``dl`` (default 10, the previous
            hard-coded value).
        noise_size: length of the noise vector. It is used both for the
            generator's ``in_noise`` and for sampling ``z``, so the two
            cannot drift apart.

    Returns:
        The trained ``(G, D)`` modules. Previously they were discarded.
    """
    G = Gen(in_noise=noise_size, out_ch=3).to(device)
    D = Discr(in_ch=4, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    def d_input(img, cond):
        # Channel-wise concat, then affine map x -> 2x - 1 (presumably
        # [0, 1] -> [-1, 1]; assumes inputs lie in [0, 1] -- confirm).
        return torch.cat([img, cond], dim=1) * 2 - 1

    iters = 0
    for _epoch in range(epochs):
        for (x, x2), _labels in dl:
            # expand() can only broadcast a size-1 channel dim, so x is
            # expected to be single-channel; replicate it to 3 channels to
            # match the generator output.
            x = x.expand(-1, 3, -1, -1).to(device)
            x2 = x2.to(device)

            # G(z, x2) needs z and x2 to share the batch dimension.
            z = torch.randn(x2.size(0), noise_size, device=device)
            fake = G(z, x2)

            # Discriminator step: fake is detached so no gradient reaches G.
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(D(d_input(fake.detach(), x2)))
            loss_fake.backward()

            loss_true = gan_loss.real(D(d_input(x, x2)))
            loss_true.backward()
            opt_D.step()

            # Generator step: D is frozen while G's loss is backpropagated
            # through it. `finally` guarantees D is unfrozen even if the
            # loss or backward pass raises.
            opt_G.zero_grad()
            freeze(D)
            try:
                loss = gan_loss.generated(D(d_input(fake, x2)))
                loss.backward()
            finally:
                unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print("Iter {}, loss true {}, loss fake {}, loss G {}".format(
                    iters, loss_true.item(), loss_fake.item(), loss.item()))

            if iters % 10 == 0:
                vis.images(x2, win='canny')
                # Detach so the visualised tensor carries no autograd history.
                vis.images(fake.detach(), opts=dict(store_history=True),
                           win='fake')
            iters += 1

    return G, D
Beispiel #4
0
def train_net(Gen, Discr, *, epochs=4, noise_size=32):
    """Train an unconditional GAN on single-channel images from ``dl``.

    Each iteration does one discriminator update (a fake batch, then a real
    batch), followed by one generator update through a frozen discriminator.
    Losses are printed every 100 iterations, and fakes are sent to visdom
    every 10 iterations.

    Relies on module-level names defined elsewhere: ``dl``, ``device``,
    ``gan_loss``, ``vis``, ``freeze``, ``unfreeze``, ``Adam`` and ``torch``.

    Args:
        Gen: generator class, built as ``Gen(in_noise=noise_size, out_ch=1)``
            and called as ``G(z)``.
        Discr: discriminator class, built as
            ``Discr(in_ch=1, out_ch=1, norm=None)``.
        epochs: number of passes over ``dl`` (default 4, the previous
            hard-coded value).
        noise_size: length of the noise vector. It is used both for the
            generator's ``in_noise`` and for sampling ``z``.

    Returns:
        The trained ``(G, D)`` modules. Previously they were discarded.
    """
    G = Gen(in_noise=noise_size, out_ch=1).to(device)
    D = Discr(in_ch=1, out_ch=1, norm=None).to(device)

    opt_G = Adam(G.parameters(), lr=2e-4, betas=(0., 0.999))
    opt_D = Adam(D.parameters(), lr=2e-4, betas=(0., 0.999))

    iters = 0
    for _epoch in range(epochs):
        for x, _labels in dl:
            x = x.to(device)

            # Noise batch follows the real batch. It was hard-coded to 16,
            # which mismatched any other loader batch size and a shorter
            # final batch.
            z = torch.randn(x.size(0), noise_size, device=device)
            fake = G(z)

            # Discriminator step. `* 2 - 1` maps x -> 2x - 1 (presumably
            # [0, 1] -> [-1, 1]; assumes inputs in [0, 1] -- confirm).
            # fake is detached so no gradient reaches G.
            opt_D.zero_grad()
            loss_fake = gan_loss.fake(D(fake.detach() * 2 - 1))
            loss_fake.backward()

            loss_true = gan_loss.real(D(x * 2 - 1))
            loss_true.backward()
            opt_D.step()

            # Generator step through a frozen D. `finally` guarantees D is
            # unfrozen even if the loss or backward pass raises.
            opt_G.zero_grad()
            freeze(D)
            try:
                loss = gan_loss.generated(D(fake * 2 - 1))
                loss.backward()
            finally:
                unfreeze(D)
            opt_G.step()

            if iters % 100 == 0:
                print("Iter {}, loss true {}, loss fake {}, loss G {}".format(
                    iters, loss_true.item(), loss_fake.item(), loss.item()))

            if iters % 10 == 0:
                # Detach so the visualised tensor carries no autograd history.
                vis.images(fake.detach(), opts=dict(store_history=True),
                           win='fake')
            iters += 1

    return G, D
Beispiel #5
0
    def G_wrap(batch):
        """Make only the generator trainable, then delegate to ``G_fun``.

        Calls ``tu.freeze`` on the discriminator ``D`` and ``tu.unfreeze`` on
        the generator ``G``. Train/eval modes are left untouched. ``tu``,
        ``D``, ``G`` and ``G_fun`` are free variables taken from the
        enclosing scope.

        Args:
            batch: forwarded unchanged to ``G_fun``.

        Returns:
            Whatever ``G_fun(batch)`` returns.
        """
        frozen, trainable = D, G
        tu.freeze(frozen)
        tu.unfreeze(trainable)
        step_output = G_fun(batch)
        return step_output
Beispiel #6
0
    def D_wrap(batch):
        """Make only the discriminator trainable, then delegate to ``D_fun``.

        Calls ``tu.freeze`` on the generator ``G`` and ``tu.unfreeze`` on the
        discriminator ``D``. Train/eval modes are left untouched. ``tu``,
        ``D``, ``G`` and ``D_fun`` are free variables taken from the
        enclosing scope.

        Args:
            batch: forwarded unchanged to ``D_fun``.

        Returns:
            Whatever ``D_fun(batch)`` returns.
        """
        frozen, trainable = G, D
        tu.freeze(frozen)
        tu.unfreeze(trainable)
        step_output = D_fun(batch)
        return step_output