Example #1
def main():
    # load training data
    trainset = Dataset('./data/brilliant_blue')

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True
    )

    # init netD and netG
    netD = Discriminator().to(device)
    netD.apply(weights_init)

    netG = Generator(nz).to(device)
    netG.apply(weights_init)


    criterion = nn.BCELoss()

    # used for visualizing training process
    fixed_noise = torch.randn(16, nz, 1, device=device)

    real_label = 1.
    fake_label = 0.

    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    for epoch in range(epoch_num):
        for step, (data, _) in enumerate(trainloader):

            real_cpu = data.to(device)
            b_size = real_cpu.size(0)

            # train netD with real batch
            label = torch.full((b_size,), real_label,
                               dtype=torch.float, device=device)
            netD.zero_grad()
            output = netD(real_cpu).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train netD with fake batch
            noise = torch.randn(b_size, nz, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()
            # train netG
            netG.zero_grad()

            label.fill_(real_label)
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, epoch_num, step, len(trainloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # save training process
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            f, a = plt.subplots(4, 4, figsize=(8, 8))
            for i in range(4):
                for j in range(4):
                    a[i][j].plot(fake[i * 4 + j].view(-1))
                    a[i][j].set_xticks(())
                    a[i][j].set_yticks(())
            plt.savefig('./img/dcgan_epoch_%d.png' % epoch)
            plt.close()
    
    # save models
    torch.save(netG, './nets/dcgan_netG.pkl')
    torch.save(netD, './nets/dcgan_netD.pkl')
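Example #1 leaves several names undefined: weights_init and the globals batch_size, nz, lr, beta1, epoch_num and device (most of the examples below share them). A minimal sketch of what these definitions might look like, using the standard DCGAN initialization; every value here is an assumption, not taken from the original source:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64  # assumed
nz = 100         # assumed latent dimension
lr = 0.0002      # assumed (common DCGAN default)
beta1 = 0.5      # assumed (common DCGAN default)
epoch_num = 25   # assumed

def weights_init(m):
    # standard DCGAN initialization: N(0, 0.02) for conv weights,
    # N(1, 0.02) for batch-norm scales with zero bias
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)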
Example #2
def main():
    # load training data
    trainset = Dataset('./data/brilliant_blue')

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True)

    # init netD and netG
    netD = Discriminator().to(device)
    netD.apply(weights_init)

    netG = Generator(nz).to(device)
    netG.apply(weights_init)

    # used for visualizing training process
    fixed_noise = torch.randn(16, nz, 1, device=device)

    # optimizers
    optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
    optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

    for epoch in range(epoch_num):
        for step, (data, _) in enumerate(trainloader):
            # training netD
            real_cpu = data.to(device)
            b_size = real_cpu.size(0)
            netD.zero_grad()

            noise = torch.randn(b_size, nz, 1, device=device)
            fake = netG(noise)

            loss_D = -torch.mean(netD(real_cpu)) + torch.mean(netD(fake.detach()))
            loss_D.backward()
            optimizerD.step()

            for p in netD.parameters():
                p.data.clamp_(-clip_value, clip_value)

            if step % n_critic == 0:
                # training netG
                noise = torch.randn(b_size, nz, 1, device=device)

                netG.zero_grad()
                fake = netG(noise)
                loss_G = -torch.mean(netD(fake))

                loss_G.backward()
                optimizerG.step()

            if step % 5 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' %
                      (epoch, epoch_num, step, len(trainloader), loss_D.item(),
                       loss_G.item()))

        # save training process
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            f, a = plt.subplots(4, 4, figsize=(8, 8))
            for i in range(4):
                for j in range(4):
                    a[i][j].plot(fake[i * 4 + j].view(-1))
                    a[i][j].set_xticks(())
                    a[i][j].set_yticks(())
            plt.savefig('./img/wgan_epoch_%d.png' % epoch)
            plt.close()
    # save model
    torch.save(netG, './nets/wgan_netG.pkl')
    torch.save(netD, './nets/wgan_netD.pkl')
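Example #2 additionally depends on clip_value and n_critic, which the snippet never defines. A hedged sketch using the defaults suggested in the WGAN paper; the values are assumptions, not taken from this source:

lr = 0.00005       # RMSprop learning rate suggested in the WGAN paper
clip_value = 0.01  # weight-clipping bound for the critic
n_critic = 5       # critic updates per generator update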
Example #3

# Create the generator.
netG = Generator(params).to(device)
# Apply the weights_init() function to randomly initialize all
# weights to mean=0.0, stdev=0.02
netG.apply(weights_init)
# Print the model.
print(netG)

# Create the discriminator.
netD = Discriminator(params).to(device)
# Apply the weights_init() function to randomly initialize all
# weights to mean=0.0, stdev=0.02
netD.apply(weights_init)
# Print the model.
print(netD)

# Binary Cross Entropy loss function.
criterion = nn.BCELoss()

fixed_noise = torch.randn(64, params['nz'], 1, 1, device=device)

real_label = 1
fake_label = 0

# Optimizer for the discriminator.
optimizerD = optim.Adam(netD.parameters(),
                        lr=params['lr'],
                        betas=(params['beta1'], 0.999))
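The snippet cuts off before the generator's optimizer is created. A hedged completion, assuming it mirrors optimizerD's hyperparameters:

# Optimizer for the generator (assumed to mirror optimizerD).
optimizerG = optim.Adam(netG.parameters(),
                        lr=params['lr'],
                        betas=(params['beta1'], 0.999))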
Example #4
# the snippet begins mid-call; the MNIST constructor is reconstructed here
# (the root path './data' is an assumption)
train_data = datasets.MNIST(root='./data',
                            train=True,
                            download=True,
                            transform=transform)
# split sizes must sum exactly to len(train_data), so derive the remainder
n_train = int(0.9 * len(train_data))
mini_train_data, mnist_restset = torch.utils.data.random_split(
    train_data, [n_train, len(train_data) - n_train])
train_loader = DataLoader(mini_train_data, batch_size=batch_size, shuffle=True)

# initialize models
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)

# initialize generator weights
generator.apply(weights_init)
# initialize discriminator weights
discriminator.apply(weights_init)
print('##### GENERATOR #####')
print(generator)
params = list(generator.parameters())
for i in range(13):
    print(params[i].size())  # sizes of the first 13 parameter tensors

print('######################')
print('\n##### DISCRIMINATOR #####')
print(discriminator)
print('######################')

# optimizers
optim_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optim_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# loss function
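The snippet is truncated right after the '# loss function' comment. For a DCGAN trained with Adam and real/fake labels like this one, the loss would typically be binary cross-entropy; this completion is an assumption:

criterion = nn.BCELoss()  # assumed: standard DCGAN adversarial loss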
Example #5
def main():
    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # DataParallel
    generator = nn.DataParallel(generator).to(device)
    discriminator = nn.DataParallel(discriminator).to(device)

    # Dataloader
    # data preparation, loaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # cudnn.benchmark = True

    # preparing the training loader
    train_loader = torch.utils.data.DataLoader(
        ImageLoader(
            opt.img_path,
            transforms.Compose([
                # transforms.Scale is deprecated; Resize is its replacement
                transforms.Resize(128),  # rescale the image, keeping the original aspect ratio
                transforms.CenterCrop(128),  # keep only the center of the rescaled image
                transforms.RandomCrop(128),  # random crop within the center crop
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            data_path=opt.data_path,
            partition='train'),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.workers,
        pin_memory=True)
    print('Training loader prepared.')

    # preparing validation loader
    val_loader = torch.utils.data.DataLoader(
        ImageLoader(
            opt.img_path,
            transforms.Compose([
                transforms.Resize(128),  # rescale the image, keeping the original aspect ratio
                transforms.CenterCrop(128),  # keep only the center of the rescaled image
                transforms.ToTensor(),
                normalize,
            ]),
            data_path=opt.data_path,
            partition='val'),
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers,
        pin_memory=True)
    print('Validation loader prepared.')

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    # ----------
    #  Training
    # ----------
    for epoch in range(opt.n_epochs):
        pbar = tqdm(total=len(train_loader))

        start_time = time.time()
        for i, data in enumerate(train_loader):

            input_var = list()
            for j in range(len(data)):
                input_var.append(data[j].to(device))

            imgs = input_var[0]
            # Adversarial ground truths
            valid = np.ones((imgs.shape[0], 1))
            valid = torch.FloatTensor(valid).to(device)
            fake = np.zeros((imgs.shape[0], 1))
            fake = torch.FloatTensor(fake).to(device)
            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))
            z = torch.FloatTensor(z).to(device)
            # Generate a batch of images
            gen_imgs = generator(z, input_var[1], input_var[2], input_var[3],
                                 input_var[4])

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()
            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                         fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            pbar.update(1)

        pbar.close()
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Time Elapsed: %f]"
            % (epoch, opt.n_epochs, i, len(train_loader), d_loss.item(),
               g_loss.item(), time.time() - start_time))

        if epoch % opt.sample_interval == 0:
            save_samples(epoch, gen_imgs.data[:25])
            save_model(epoch, generator.state_dict(),
                       discriminator.state_dict())
Example #6
import torch.nn.init as init
# from config import *
import config
import torch.optim as optim
import data

if __name__ == "__main__":
    # Create the nets
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset, dataloader = data.create_dataset()

    gen_net = Generator().to(device)
    gen_net.apply(weights_init)

    dis_net = Discriminator().to(device)
    dis_net.apply(weights_init)

    # Fixed noise used to visualize the generator's progress during training
    fixed_noise = torch.randn(config.batchSize, config.nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    criterion = nn.BCELoss()

    # We need 2 separate optimizers, one for the Generator and one for the Discriminator
    gen_opt = optim.Adam(gen_net.parameters(),
                         lr=config.lr,
                         betas=(config.beta1, 0.999))

    dis_opt = optim.Adam(dis_net.parameters(),
                         lr=config.lr,
                         betas=(config.beta1, 0.999))
Example #7
                                             drop_last=True)
else:
    raise Exception("Not a valid dataset")

G = Generator(args.num_noises, NUM_COLORS, args.depths, IMAGE_SIZE).to(device)
D = Discriminator(NUM_COLORS, args.depths, IMAGE_SIZE).to(device)


def init_weight(model):
    classname = model.__class__.__name__
    # layer class names are capitalized (e.g. 'Conv2d'), so match 'Conv'
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(model.weight.data, 0, 0.02)


G.apply(init_weight)
D.apply(init_weight)

criterion_d = torch.nn.BCELoss()
# feature matching trains the generator with an MSE objective instead of BCE
criterion_g = torch.nn.MSELoss() if args.feature_matching else torch.nn.BCELoss()
optimizer_g = torch.optim.Adam(G.parameters(),
                               lr=args.learning_rate,
                               betas=[args.beta_1, args.beta_2])
optimizer_d = torch.optim.Adam(D.parameters(),
                               lr=args.learning_rate,
                               betas=[args.beta_1, args.beta_2])

if __name__ == "__main__":
    # One-sided label smoothing
    pos_labels = torch.full((args.batch_size, 1), args.alpha, device=device)
    neg_labels = torch.zeros((args.batch_size, 1), device=device)
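A hedged sketch of how the smoothed labels above would enter one critic update: the loop body, the noise shape, and the discriminator's (batch, 1) output shape are assumptions, while D, G, criterion_d, optimizer_d, pos_labels and neg_labels come from the code above. drop_last=True in the loader guarantees every batch matches the fixed label size:

    for real, _ in dataloader:  # dataloader assumed from the truncated setup above
        real = real.to(device)
        optimizer_d.zero_grad()
        noise = torch.randn(args.batch_size, args.num_noises, 1, 1, device=device)
        fake = G(noise)
        # one-sided smoothing: real targets are alpha (e.g. 0.9); fake targets stay 0
        loss_real = criterion_d(D(real), pos_labels)
        loss_fake = criterion_d(D(fake.detach()), neg_labels)
        (loss_real + loss_fake).backward()
        optimizer_d.step()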
Example #8
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

cuda = True if torch.cuda.is_available() else False
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
print('Using {} dataloader workers per process'.format(nw))
dataloader = torch.utils.data.DataLoader(
    LoadData(
        dped_dir=dped_dir,
        dataset_size=train_size,
        image_size=PATCH_SIZE,
        test=False,
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=nw,
    pin_memory=True,
)

Example #9
def main():
    # load training data
    trainset = Dataset('./data/brilliant_blue')

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True)

    # init netD and netG
    netD = Discriminator().to(device)
    netD.apply(weights_init)

    netG = Generator(nz).to(device)
    netG.apply(weights_init)

    # used for visualizing training process
    fixed_noise = torch.randn(16, nz, 1, device=device)

    # optimizers
    # optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
    # optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, beta2))
    optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
    optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

    for epoch in range(epoch_num):
        for step, (data, _) in enumerate(trainloader):
            # training netD
            real_cpu = data.to(device)
            b_size = real_cpu.size(0)
            netD.zero_grad()

            noise = torch.randn(b_size, nz, 1, device=device)
            fake = netG(noise)

            # gradient penalty (WGAN-GP): push the critic's gradient norm
            # towards 1 along interpolates of real and fake samples
            eps = torch.rand(b_size, 1, 1, device=device)
            x_p = (eps * real_cpu + (1 - eps) * fake.detach()).requires_grad_(True)
            d_out = netD(x_p)
            grad = autograd.grad(outputs=d_out,
                                 inputs=x_p,
                                 grad_outputs=torch.ones_like(d_out),
                                 create_graph=True,
                                 retain_graph=True)[0].view(b_size, -1)
            grad_norm = torch.norm(grad, 2, dim=1)
            grad_penalty = p_coeff * torch.pow(grad_norm - 1, 2)

            loss_D = torch.mean(netD(fake.detach()) - netD(real_cpu)) + grad_penalty.mean()
            loss_D.backward()
            optimizerD.step()

            # no weight clipping here: the gradient penalty replaces it

            if step % n_critic == 0:
                # training netG
                noise = torch.randn(b_size, nz, 1, device=device)

                netG.zero_grad()
                fake = netG(noise)
                loss_G = -torch.mean(netD(fake))

                loss_G.backward()
                optimizerG.step()

            if step % 5 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' %
                      (epoch, epoch_num, step, len(trainloader), loss_D.item(),
                       loss_G.item()))

        # save training process
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            f, a = plt.subplots(4, 4, figsize=(8, 8))
            for i in range(4):
                for j in range(4):
                    a[i][j].plot(fake[i * 4 + j].view(-1))
                    a[i][j].set_xticks(())
                    a[i][j].set_yticks(())
            plt.savefig('./img/wgan_gp_epoch_%d.png' % epoch)
            plt.close()
    # save model
    torch.save(netG, './nets/wgan_gp_netG.pkl')
    torch.save(netD, './nets/wgan_gp_netD.pkl')
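Example #9 never defines p_coeff or n_critic. A sketch with the defaults from the WGAN-GP paper; these are assumptions, not values from this repository:

p_coeff = 10   # gradient-penalty coefficient (lambda in the WGAN-GP paper)
n_critic = 5   # critic updates per generator update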
Example #10
def train():
    os.makedirs('log', exist_ok=True)

    ds = datasets.ImageFolder(root=data_root,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))

    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)

    net_g = Generator(n_latent_vector, n_g_filters).to(device)
    net_g.apply(weight_init)

    net_d = Discriminator(n_d_filters).to(device)
    net_d.apply(weight_init)

    if os.path.exists(model_save_path):
        all_state_dict = torch.load(model_save_path)
        net_d.load_state_dict(all_state_dict['d_state_dict'])
        net_g.load_state_dict(all_state_dict['g_state_dict'])
        print('model restored from {}'.format(model_save_path))

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(1, n_latent_vector, 1, 1, device=device)

    real_label = 1
    fake_label = 0

    optimizer_d = optim.Adam(net_d.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(net_g.parameters(), lr=lr, betas=(0.5, 0.999))

    print('start training...')
    
    try:
        for epoch in range(epochs):
            for i, data in enumerate(dataloader, 0):
                # update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
                net_d.zero_grad()
                real_cpu = data[0].to(device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
                output = net_d(real_cpu).view(-1)
                err_d_real = criterion(output, label)
                err_d_real.backward()

                d_x = output.mean().item()

                # train with fake batch
                noise = torch.randn(b_size, n_latent_vector, 1, 1, device=device)
                fake = net_g(noise)
                label.fill_(fake_label)
                output = net_d(fake.detach()).view(-1)
                err_d_fake = criterion(output, label)
                err_d_fake.backward()

                d_g_z1 = output.mean().item()
                err_d = err_d_real + err_d_fake
                optimizer_d.step()

                # update Generator
                net_g.zero_grad()
                label.fill_(real_label)
                output = net_d(fake).view(-1)
                err_g = criterion(output, label)
                err_g.backward()
                d_g_z2 = output.mean().item()
                optimizer_g.step()

                if i % 50 == 0:
                    print(f'Epoch: {epoch}, loss_d: {err_d.item()}, loss_g: {err_g.item()}')
        
            if epoch % 2 == 0 and epoch != 0:
                with torch.no_grad():
                    fake = net_g(fixed_noise).detach().cpu().numpy()
                    fake = np.transpose(np.squeeze(fake, axis=0), (1, 2, 0))
                    # denormalize (assuming a tanh output in [-1, 1]) to [0, 255] for imwrite
                    fake = ((fake * 0.5 + 0.5) * 255).clip(0, 255).astype(np.uint8)
                    cv2.imwrite('log/{}_fake.png'.format(epoch), fake)
                    print('record a fake image to local.')
        
    except KeyboardInterrupt:
        print('interrupted, try saving the model')
        all_state_dict = {
            'd_state_dict': net_d.state_dict(),
            'g_state_dict': net_g.state_dict(),
        }
        torch.save(all_state_dict, model_save_path)
        print('model saved...')
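As written, train() checkpoints only on KeyboardInterrupt. A small assumed addition, placed after the try/except block, that also saves once training finishes normally:

    # save at normal completion too, in the same format as the interrupt handler
    all_state_dict = {
        'd_state_dict': net_d.state_dict(),
        'g_state_dict': net_g.state_dict(),
    }
    torch.save(all_state_dict, model_save_path)
    print('model saved...')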