Exemplo n.º 1
0
# Apply weights_init() to every submodule of the discriminator.
# NOTE(review): an earlier comment said stddev=0.2; the DCGAN paper (and the
# usual weights_init) uses N(0, 0.02) — confirm against weights_init.
netD.apply(weights_init)
# Print the model.
print(netD)

# Binary Cross Entropy loss function.
criterion = nn.BCELoss()

# Fixed batch of 64 latent vectors of shape (nz, 1, 1); presumably reused to
# visualize generator progress (see img_list below).
fixed_noise = torch.randn(64, params['nz'], 1, 1, device=device)

# Float labels: recent PyTorch versions make torch.full() infer its dtype from
# the fill value, so integer labels would produce an integer target tensor
# that BCELoss rejects.
real_label = 1.
fake_label = 0.

# Optimizer for the discriminator.
optimizerD = optim.Adam(netD.parameters(),
                        lr=params['lr'],
                        betas=(params['beta1'], 0.999))
# Optimizer for the generator.
optimizerG = optim.Adam(netG.parameters(),
                        lr=params['lr'],
                        betas=(params['beta1'], 0.999))

# Stores generated images as training progresses.
img_list = []
# Stores generator losses during training.
G_losses = []
# Stores discriminator losses during training.
D_losses = []

# Global iteration counter.
iters = 0
Exemplo n.º 2
0
def main():
    """Train a 1-D DCGAN on the ``./data/brilliant_blue`` dataset.

    Each step updates the discriminator on a real batch and on a detached
    fake batch, then updates the generator so the discriminator scores its
    fakes as real. After every epoch, 16 samples generated from a constant
    noise batch are plotted to ``./img``; both networks are pickled to
    ``./nets`` when training finishes. Relies on module-level globals
    (batch_size, nz, lr, beta1, epoch_num, device, Dataset, Generator,
    Discriminator, weights_init).
    """
    loader = torch.utils.data.DataLoader(
        Dataset('./data/brilliant_blue'), batch_size=batch_size, shuffle=True
    )

    # Build each network on the target device, then apply the custom init.
    netD = Discriminator().to(device)
    netD.apply(weights_init)
    netG = Generator(nz).to(device)
    netG.apply(weights_init)

    bce = nn.BCELoss()

    # Constant latent batch reused every epoch so the saved plots are
    # comparable; shape (16, nz, 1) matches the training noise below.
    fixed_noise = torch.randn(16, nz, 1, device=device)

    REAL, FAKE = 1., 0.

    adam_betas = (beta1, 0.999)
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

    n_batches = len(loader)
    for epoch in range(epoch_num):
        for step, (batch, _) in enumerate(loader):
            real = batch.to(device)
            n = real.size(0)

            # Discriminator, real half: targets are REAL.
            target = torch.full((n,), REAL, dtype=torch.float, device=device)
            netD.zero_grad()
            d_real = netD(real).view(-1)
            errD_real = bce(d_real, target)
            errD_real.backward()
            D_x = d_real.mean().item()

            # Discriminator, fake half: generate, detach from netG, targets FAKE.
            fake = netG(torch.randn(n, nz, 1, device=device))
            target.fill_(FAKE)
            d_fake = netD(fake.detach()).view(-1)
            errD_fake = bce(d_fake, target)
            errD_fake.backward()
            D_G_z1 = d_fake.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            # Generator: push netD to score the same fakes as REAL.
            netG.zero_grad()
            target.fill_(REAL)
            d_gen = netD(fake).view(-1)
            errG = bce(d_gen, target)
            errG.backward()
            D_G_z2 = d_gen.mean().item()
            optimizerG.step()

            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, epoch_num, step, n_batches,
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Snapshot the generator's output for the fixed noise (no autograd).
        with torch.no_grad():
            samples = netG(fixed_noise).detach().cpu()
            fig, grid = plt.subplots(4, 4, figsize=(8, 8))
            for idx, ax in enumerate(grid.flat):
                ax.plot(samples[idx].view(-1))
                ax.set_xticks(())
                ax.set_yticks(())
            plt.savefig('./img/dcgan_epoch_%d.png' % epoch)
            plt.close()

    # Pickle both full modules (architecture + weights).
    torch.save(netG, './nets/dcgan_netG.pkl')
    torch.save(netD, './nets/dcgan_netD.pkl')
Exemplo n.º 3
0
class Solver(object):
    """Drive training and testing of a generator/discriminator (GAN) pair.

    Args:
        train_loader: DataLoader yielding dict batches with the keys
            'phone_image', 'dslr_image' and 'name'.
        test_loader: DataLoader with the same batch layout, used by test().
        config: namespace of run settings (mode, cuda, lr, beta1, beta2, wd,
            nz, ngf, ndf, nc, epochs, batch_size, show_every, epoch_save,
            save_folder, test_fold, model, load, pretrained_model).
    """

    def __init__(self, train_loader, test_loader, config):
        # Training-set DataLoader.
        self.train_loader = train_loader
        # Test-set DataLoader.
        self.test_loader = test_loader
        # Run configuration.
        self.config = config
        # Logging interval in samples (converted to batches in train()).
        self.show_every = config.show_every
        # Epoch indices after which the learning rate is multiplied by 0.1.
        self.lr_decay_epoch = [
            15,
        ]
        # Current learning rate; decayed by train(). Previously train() read
        # this attribute without it ever being assigned (AttributeError).
        self.lr = config.lr
        # Build networks and optimizers.
        self.build_model()
        # Loss function
        self.adversarial_loss = torch.nn.BCELoss()
        # Test mode: load the pre-trained model given by config.model.
        if config.mode == 'test':
            print('Loading pre-trained model from %s...' % self.config.model)
            # NOTE(review): the same checkpoint is loaded into both netG and
            # netD, which only works if it matches both architectures —
            # confirm the intended checkpoint layout.
            self.netG.load_state_dict(self._load_checkpoint(self.config.model))
            self.netD.load_state_dict(self._load_checkpoint(self.config.model))

    def _load_checkpoint(self, path):
        """Load a checkpoint, remapping tensors to CPU when CUDA is disabled."""
        if self.config.cuda:
            return torch.load(path)
        return torch.load(path, map_location='cpu')

    def _to_device(self, tensor):
        """Return ``tensor`` on the GPU when ``config.cuda`` is set, else unchanged."""
        return tensor.cuda() if self.config.cuda else tensor

    def _decay_lr(self, factor=0.1):
        """Multiply the current learning rate by ``factor`` in both optimizers.

        The optimizers' ``param_groups`` are updated in place, so Adam's
        moment estimates survive the decay (re-creating the optimizers, as
        the previous code attempted, would reset them).
        """
        self.lr = self.lr * factor
        for optimizer in (self.optimizerG, self.optimizerD):
            for group in optimizer.param_groups:
                group['lr'] = self.lr

    def print_network(self, model, name):
        """Print ``name``, the module structure and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Create netG/netD, optionally move them to GPU, load weights, build optimizers."""
        self.netG = Generator(nz=self.config.nz,
                              ngf=self.config.ngf,
                              nc=self.config.nc)
        self.netD = Discriminator(nz=self.config.nz,
                                  ndf=self.config.ndf,
                                  nc=self.config.nc)
        if self.config.cuda:
            # Previously referenced a non-existent ``self.net``.
            self.netG = self.netG.cuda()
            self.netD = self.netD.cuda()
            cudnn.benchmark = True
        # Start in eval mode; train() switches both nets to training mode.
        self.netG.eval()  # use_global_stats = True
        self.netD.eval()
        # Weights come from config.pretrained_model when config.load is
        # empty, otherwise from config.load.
        # NOTE(review): as in __init__, one checkpoint feeds both networks —
        # confirm this matches how checkpoints are written.
        if self.config.load == '':
            checkpoint = self.config.pretrained_model
        else:
            checkpoint = self.config.load
        self.netG.load_state_dict(self._load_checkpoint(checkpoint))
        self.netD.load_state_dict(self._load_checkpoint(checkpoint))

        # Optimizers
        self.optimizerD = Adam(self.netD.parameters(),
                               lr=self.config.lr,
                               betas=(self.config.beta1, self.config.beta2),
                               weight_decay=self.config.wd)
        self.optimizerG = Adam(self.netG.parameters(),
                               lr=self.config.lr,
                               betas=(self.config.beta1, self.config.beta2),
                               weight_decay=self.config.wd)
        # Print network structure
        self.print_network(self.netG, 'Generator Structure')
        self.print_network(self.netD, 'Discriminator Structure')

    def test(self):
        """Run netG over test_loader and write each prediction as a PNG into config.test_fold."""
        mode_name = 'enhanced'
        time_s = time.time()
        # len(DataLoader) counts batches, so the reported FPS is a per-image
        # rate only when batch_size == 1.
        img_num = len(self.test_loader)
        for data_batch in self.test_loader:
            phone_image, _, name = data_batch['phone_image'], data_batch[
                'dslr_image'], data_batch['name']
            with torch.no_grad():
                images = self._to_device(torch.Tensor(phone_image))
                preds = self.netG(images).cpu().data.numpy()
                # NOTE(review): preds keeps its batch/channel layout and
                # `name` is sliced with [:-4] as if it were a single filename;
                # both look like they assume batch_size == 1 — confirm.
                cv2.imwrite(
                    os.path.join(self.config.test_fold,
                                 name[:-4] + '_' + mode_name + '.png'), preds)
        time_e = time.time()
        print('Speed: %f FPS' % (img_num / (time_e - time_s)))
        print('Test Done!')

    def train(self):
        """Alternate generator and discriminator updates for config.epochs epochs.

        Saves both state dicts every ``config.epoch_save`` epochs and at the
        end, and decays the learning rate after each epoch in lr_decay_epoch.
        """
        # build_model() leaves both networks in eval mode; switch to training
        # mode so BatchNorm/Dropout layers behave as intended while learning.
        self.netG.train()
        self.netD.train()
        # Guard against show_every < batch_size, which made the interval 0 and
        # the modulo below raise ZeroDivisionError.
        log_interval = max(1, self.show_every // self.config.batch_size)
        for epoch in range(self.config.epochs):
            for i, data_batch in enumerate(self.train_loader):
                phone_image = self._to_device(data_batch['phone_image'])
                n = phone_image.size(0)
                # Adversarial ground truths, on the same device as the nets.
                valid = self._to_device(torch.ones(n, 1))
                fake = self._to_device(torch.zeros(n, 1))

                # -----------------
                #  Train Generator
                # -----------------
                self.optimizerG.zero_grad()

                # Sample noise as generator input.
                # NOTE(review): z is (n, nz); confirm Generator does not expect
                # a (n, nz, 1, 1) input.
                z = self._to_device(
                    torch.Tensor(np.random.normal(0, 1, (n, self.config.nz))))

                # Generate a batch of images
                gen_imgs = self.netG(z)

                # Loss measures generator's ability to fool the discriminator
                g_loss = self.adversarial_loss(self.netD(gen_imgs), valid)

                g_loss.backward()
                self.optimizerG.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.optimizerD.zero_grad()

                # Measure discriminator's ability to classify real from generated samples
                real_loss = self.adversarial_loss(self.netD(phone_image), valid)
                fake_loss = self.adversarial_loss(
                    self.netD(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                self.optimizerD.step()

                # Progress report
                if i % log_interval == 0:
                    print(
                        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                        % (epoch, self.config.epochs, i, len(
                            self.train_loader), d_loss.item(), g_loss.item()))
                    print('Learning rate: ' + str(self.lr))

            # Periodic checkpoint
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.netG.state_dict(),
                    '%s/models/generator/epoch_%d.pth' %
                    (self.config.save_folder, epoch + 1))
                torch.save(
                    self.netD.state_dict(),
                    '%s/models/discriminator/epoch_%d.pth' %
                    (self.config.save_folder, epoch + 1))

            # Learning-rate decay. The old code crashed (self.lr unset, and
            # optimizer kwargs were passed to filter()) and would have ignored
            # the decayed value anyway.
            if epoch in self.lr_decay_epoch:
                self._decay_lr(0.1)

        # Final checkpoint (previously saved the non-existent ``self.net``).
        torch.save(self.netG.state_dict(),
                   '%s/models/generator/final.pth' % self.config.save_folder)
        torch.save(
            self.netD.state_dict(),
            '%s/models/discriminator/final.pth' % self.config.save_folder)
Exemplo n.º 4
0
def main():
    """Train a 1-D WGAN with a weight-clipped critic on ./data/brilliant_blue.

    The critic (netD) is updated every step and its weights are clamped to
    [-clip_value, clip_value]; the generator is updated once every
    ``n_critic`` steps. A 4x4 grid of samples from fixed noise is plotted to
    ./img after each epoch and both modules are pickled to ./nets at the end.
    Relies on module-level globals (batch_size, nz, lr, epoch_num, device,
    clip_value, n_critic, Dataset, Generator, Discriminator, weights_init).
    """
    # load training data
    trainset = Dataset('./data/brilliant_blue')

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True)

    # init netD and netG
    netD = Discriminator().to(device)
    netD.apply(weights_init)

    netG = Generator(nz).to(device)
    netG.apply(weights_init)

    # used for visualizing training process
    fixed_noise = torch.randn(16, nz, 1, device=device)

    # optimizers
    optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
    optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

    for epoch in range(epoch_num):
        for step, (data, _) in enumerate(trainloader):
            # ---- critic (netD) update ----
            real_cpu = data.to(device)
            b_size = real_cpu.size(0)
            netD.zero_grad()

            noise = torch.randn(b_size, nz, 1, device=device)
            # Detach: the critic step only needs gradients w.r.t. netD.
            # Without it, backward also ran through netG for nothing (those
            # gradients were discarded by netG.zero_grad() before the G step).
            fake = netG(noise).detach()

            # Negated Wasserstein estimate: -(E[D(real)] - E[D(fake)]).
            loss_D = -torch.mean(netD(real_cpu)) + torch.mean(netD(fake))
            loss_D.backward()
            optimizerD.step()

            # Weight clipping keeps the critic (approximately) Lipschitz.
            for p in netD.parameters():
                p.data.clamp_(-clip_value, clip_value)

            if step % n_critic == 0:
                # ---- generator (netG) update ----
                noise = torch.randn(b_size, nz, 1, device=device)

                netG.zero_grad()
                fake = netG(noise)
                loss_G = -torch.mean(netD(fake))

                # netD receives gradients here too; clear them so they are
                # not mistaken for critic gradients.
                netD.zero_grad()
                loss_G.backward()
                optimizerG.step()

            # loss_G is always bound here: step 0 satisfies step % n_critic == 0.
            if step % 5 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' %
                      (epoch, epoch_num, step, len(trainloader), loss_D.item(),
                       loss_G.item()))

        # save training process
        with torch.no_grad():
            samples = netG(fixed_noise).detach().cpu()
            fig, axes = plt.subplots(4, 4, figsize=(8, 8))
            for i in range(4):
                for j in range(4):
                    axes[i][j].plot(samples[i * 4 + j].view(-1))
                    axes[i][j].set_xticks(())
                    axes[i][j].set_yticks(())
            plt.savefig('./img/wgan_epoch_%d.png' % epoch)
            plt.close()
    # save model
    torch.save(netG, './nets/wgan_netG.pkl')
    torch.save(netD, './nets/wgan_netD.pkl')
Exemplo n.º 5
0
def main():
    """Train a GAN whose generator is conditioned on four extra inputs.

    Each training batch is a sequence of tensors: element 0 is the real image
    batch, and elements 1-4 are passed to the generator after the noise
    vector. Hyper-parameters and paths come from the module-level ``opt``.
    Every ``opt.sample_interval`` epochs, samples and state dicts are
    persisted via ``save_samples`` / ``save_model``.

    Raises:
        ValueError: if the training partition yields no batches.
    """
    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # DataParallel
    generator = nn.DataParallel(generator).to(device)
    discriminator = nn.DataParallel(discriminator).to(device)

    # ImageNet channel statistics for input normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # cudnn.benchmark = True

    # Training loader. transforms.Scale was a deprecated alias of Resize and
    # is gone from recent torchvision releases; Resize(int) has the same
    # semantics (resize the shorter edge, keep the aspect ratio).
    train_loader = torch.utils.data.DataLoader(
        ImageLoader(
            opt.img_path,
            transforms.Compose([
                transforms.Resize(128),  # shorter edge -> 128, keep aspect ratio
                transforms.CenterCrop(128),  # keep the central 128x128 region
                # NOTE(review): a 128 crop of a 128x128 image is a no-op; a
                # larger Resize/CenterCrop was presumably intended.
                transforms.RandomCrop(128),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            data_path=opt.data_path,
            partition='train'),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.workers,
        pin_memory=True)
    print('Training loader prepared.')
    # The epoch summary below reads loop variables; fail clearly instead of
    # with a NameError if there is nothing to iterate.
    if len(train_loader) == 0:
        raise ValueError('training partition is empty; nothing to train on')

    # Validation loader (deterministic preprocessing, no shuffling).
    val_loader = torch.utils.data.DataLoader(
        ImageLoader(
            opt.img_path,
            transforms.Compose([
                transforms.Resize(128),  # shorter edge -> 128, keep aspect ratio
                transforms.CenterCrop(128),  # keep the central 128x128 region
                transforms.ToTensor(),
                normalize,
            ]),
            data_path=opt.data_path,
            partition='val'),
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers,
        pin_memory=True)
    print('Validation loader prepared.')

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    # ----------
    #  Training
    # ----------
    for epoch in range(opt.n_epochs):
        start_time = time.time()
        # Context manager closes the progress bar even if an iteration raises.
        with tqdm(total=len(train_loader)) as pbar:
            for i, data in enumerate(train_loader):
                input_var = [tensor.to(device) for tensor in data]

                imgs = input_var[0]
                n = imgs.shape[0]
                # Adversarial ground truths
                valid = torch.ones(n, 1, device=device)
                fake = torch.zeros(n, 1, device=device)

                # -----------------
                #  Train Generator
                # -----------------
                optimizer_G.zero_grad()
                # Sample noise as generator input
                z = np.random.normal(0, 1, (n, opt.latent_dim))
                z = torch.FloatTensor(z).to(device)
                # Generate a batch of images from noise plus the four
                # conditioning inputs of this batch.
                gen_imgs = generator(z, input_var[1], input_var[2],
                                     input_var[3], input_var[4])

                # Loss measures generator's ability to fool the discriminator
                g_loss = adversarial_loss(discriminator(gen_imgs), valid)

                g_loss.backward()
                optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                optimizer_D.zero_grad()

                # Measure discriminator's ability to classify real from generated samples
                real_loss = adversarial_loss(discriminator(imgs), valid)
                fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                             fake)
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                optimizer_D.step()

                pbar.update(1)

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Time Elapsed: %f]"
            % (epoch, opt.n_epochs, i, len(train_loader), d_loss.item(),
               g_loss.item(), time.time() - start_time))

        if epoch % opt.sample_interval == 0:
            save_samples(epoch, gen_imgs.data[:25])
            save_model(epoch, generator.state_dict(),
                       discriminator.state_dict())
Exemplo n.º 6
0
# Handle multi-gpu if desired
# if (device.type == 'cuda') and (ngpu > 1):
#     netG = nn.DataParallel(netG, list(range(ngpu)))
#     netD = nn.DataParallel(netD, list(range(ngpu)))
# netG.apply(weights_init)
# netD.apply(weights_init)

# Place both networks on the same `device` used for fixed_noise and the
# training batches below; a hard-coded .cuda() caused a device mismatch
# (or a crash) whenever `device` was not CUDA.
netG = netG.to(device)
netD = netD.to(device)

# Initialize BCELoss function
criterion = nn.BCELoss()
# Fixed latent batch: nz noise dims plus t_in extra (presumably conditioning)
# dims, shaped (64, nz + t_in, 1, 1).
fixed_noise = torch.randn(64, nz + t_in, 1, 1, device=device)

# Float labels so torch.full()/fill_() produce float targets for BCELoss.
real_label = 1.
fake_label = 0.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# Resumes from start_epoch (presumably set when loading a checkpoint).
for epoch in range(start_epoch, num_epochs):
    for i, (img, t_v) in enumerate(dataloader, 0):
        netD.zero_grad()
        real_cpu = img.to(device)
Exemplo n.º 7
0
    transforms.Normalize([0.5 for _ in range(channels_img)],
                         [0.5 for _ in range(channels_img)]),
])

# NOTE(review): `transforms` is passed as the dataset transform, so it is
# presumably rebound to a transforms.Compose(...) pipeline in the preceding
# statement (shadowing the torchvision module) — confirm.
dataset = datasets.MNIST(root='../datasets/',
                         train=True,
                         transform=transforms,
                         download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
gen = Generator(z_dim, channels_img, features_gen).to(device)
disc = Discriminator(channels_img, features_disc).to(device)
# Resume both networks from their saved checkpoints.
load_model(disc, disc_file, device)
load_model(gen, gen_file, device)

# DCGAN-style Adam settings (beta1 = 0.5).
opt_gen = optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=learning_rate, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Constant latent batch for the TensorBoard previews.
fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)
writer_real = SummaryWriter('runs/dcgan_mnist/real')
writer_fake = SummaryWriter('runs/dcgan_mnist/fake')
step = 0

gen.train()
disc.train()

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # Size the noise from the real batch: the final batch is smaller than
        # batch_size whenever len(dataset) % batch_size != 0, and a mismatch
        # would break the real-vs-fake loss terms.
        noise = torch.randn((real.size(0), z_dim, 1, 1)).to(device)
        fake = gen(noise)
def main():
    """Train a 1-D WGAN-GP (Wasserstein GAN with gradient penalty).

    The critic loss is E[D(fake)] - E[D(real)] plus ``p_coeff`` times the
    mean of (||grad_x D(x_hat)||_2 - 1)^2 over random real/fake interpolates
    x_hat. The generator is updated once every ``n_critic`` steps. A 4x4
    grid of samples from fixed noise is plotted to ./img after each epoch and
    both modules are pickled to ./nets at the end. Relies on module-level
    globals (batch_size, nz, lr, epoch_num, device, p_coeff, n_critic,
    autograd, Dataset, Generator, Discriminator, weights_init).
    """
    # load training data
    trainset = Dataset('./data/brilliant_blue')

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True)

    # init netD and netG
    netD = Discriminator().to(device)
    netD.apply(weights_init)

    netG = Generator(nz).to(device)
    netG.apply(weights_init)

    # used for visualizing training process
    fixed_noise = torch.randn(16, nz, 1, device=device)

    # optimizers
    # NOTE(review): the WGAN-GP paper uses Adam(betas=(0.0, 0.9)); RMSprop is
    # kept to preserve the existing hyper-parameters.
    optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
    optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

    for epoch in range(epoch_num):
        for step, (data, _) in enumerate(trainloader):
            # ---- critic (netD) update ----
            real_cpu = data.to(device)
            b_size = real_cpu.size(0)
            netD.zero_grad()

            noise = torch.randn(b_size, nz, 1, device=device)
            # Detach: neither the critic loss nor the penalty should
            # backpropagate into netG.
            fake = netG(noise).detach()

            # Gradient penalty on random interpolates between real and fake.
            # eps is (b_size, 1, 1) and broadcasts over the trailing dims,
            # i.e. it assumes 3-D samples like the (b_size, nz, 1) noise.
            # Everything is built on `device`; the previous code mixed CPU
            # `data`/`eps` with device-resident `fake`.
            eps = torch.rand(b_size, 1, 1, device=device)
            x_p = (eps * real_cpu + (1 - eps) * fake).requires_grad_(True)
            # Differentiate the SUM of critic outputs so each row of `grad` is
            # that sample's own input gradient (.mean() scaled it by 1/b_size).
            grad = autograd.grad(netD(x_p).sum(),
                                 x_p,
                                 create_graph=True,
                                 retain_graph=True)[0].view(b_size, -1)
            grad_norm = torch.norm(grad, 2, 1)
            grad_penalty = p_coeff * torch.pow(grad_norm - 1, 2).mean()

            # The penalty was previously computed but never added to the loss.
            loss_D = torch.mean(netD(fake) - netD(real_cpu)) + grad_penalty
            loss_D.backward()
            optimizerD.step()

            # No weight clipping: in WGAN-GP the gradient penalty replaces it,
            # and clipping on top of the penalty distorts the critic.

            if step % n_critic == 0:
                # ---- generator (netG) update ----
                noise = torch.randn(b_size, nz, 1, device=device)

                netG.zero_grad()
                fake = netG(noise)
                loss_G = -torch.mean(netD(fake))

                # netD receives gradients here too; clear them.
                netD.zero_grad()
                loss_G.backward()
                optimizerG.step()

            # loss_G is always bound here: step 0 satisfies step % n_critic == 0.
            if step % 5 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' %
                      (epoch, epoch_num, step, len(trainloader), loss_D.item(),
                       loss_G.item()))

        # save training process
        with torch.no_grad():
            samples = netG(fixed_noise).detach().cpu()
            fig, axes = plt.subplots(4, 4, figsize=(8, 8))
            for i in range(4):
                for j in range(4):
                    axes[i][j].plot(samples[i * 4 + j].view(-1))
                    axes[i][j].set_xticks(())
                    axes[i][j].set_yticks(())
            plt.savefig('./img/wgan_gp_epoch_%d.png' % epoch)
            plt.close()
    # save model
    torch.save(netG, './nets/wgan_gp_netG.pkl')
    torch.save(netD, './nets/wgan_gp_netD.pkl')
Exemplo n.º 9
0
    # Discriminator on the target device, with the custom weight init.
    dis_net = Discriminator().to(device)
    dis_net.apply(weights_init)

    # Fixed batch of latent vectors, shape (batchSize, nz, 1, 1).
    # NOTE(review): not used in the visible code; presumably intended for
    # visualizing generator output over training — confirm.
    fixed_noise = torch.randn(config.batchSize, config.nz, 1, 1, device=device)
    # Integer class labels. NOTE(review): if these reach torch.full() without
    # an explicit float dtype, recent PyTorch infers an integer tensor that
    # BCELoss rejects — consider 1. / 0. (not used in the visible code).
    real_label = 1
    fake_label = 0

    criterion = nn.BCELoss()

    # Two separate optimizers: one for the generator, one for the discriminator.
    gen_opt = optim.Adam(gen_net.parameters(),
                         lr=config.lr,
                         betas=(config.beta1, 0.999))

    dis_opt = optim.Adam(dis_net.parameters(),
                         lr=config.lr,
                         betas=(config.beta1, 0.999))

    # Best-so-far combined error used for checkpointing. It starts huge, so
    # the first epoch is virtually always treated as an improvement.
    # NOTE(review): never lowered in the visible code — presumably updated
    # after a checkpoint is saved; confirm, otherwise every epoch checkpoints.
    max_err = 99999999999999999999

    for epoch in tqdm(range(config.EPOCHS)):
        # One training pass; returns the generator and discriminator errors
        # for this epoch.
        err_gen, err_disc = engine.train_step(dataloader, criterion, gen_net,
                                              dis_net, gen_opt, dis_opt,
                                              device)
        print("Epochs = {}, Generator error = {}, Discriminator error = {}".
              format(epoch, err_gen, err_disc))

        # Checkpoint when the summed error improves on the best so far.
        if (err_gen + err_disc < max_err):
            print("Checkpointing the better model")
Exemplo n.º 10
0
def main():
    """Optimize a bank of images to match a target distance matrix while looking real.

    A tensor of ``dataSize`` images (3 x imageSize x imageSize, kept in
    [0, 1]) and a scalar ``scale`` are optimized jointly. The loss has two
    parts:

    * the squared error between pairwise distances of random 2-D points and
      ``scale`` times the distances from ``models.PerceptualLoss``;
    * a weighted BCE "realness" term from a discriminator.

    The discriminator is trained in alternation on real cat photos versus
    the current images. Every 100 iterations, a batch grid and a scatter
    plot are written to ``images/<timestamp>/``. Requires a CUDA device.
    """
    dataSize = 128
    batchSize = 8
    # imageSize = 32
    imageSize = 64
    initWithCats = True

    # Optional discriminator checkpoint to start from (None = fresh init).
    # discCheckpointPath = r'E:\projects\visus\PyTorch-GAN\implementations\dcgan\checkpoints\2020_07_10_15_53_34\disc_step4800.pth'
    # discCheckpointPath = r'E:\projects\visus\pytorch-examples\dcgan\out\netD_epoch_24.pth'
    discCheckpointPath = None

    gpu = torch.device('cuda')

    # Real cat photos for the discriminator. ToTensor yields [0, 1], and
    # Normalize([0.5], [0.5]) then maps them to [-1, 1].
    imageRootPath = r'E:\data\cat-vs-dog\cat'
    catDataset = CatDataset(
        imageSubdirPath=imageRootPath,
        transform=transforms.Compose([
            transforms.Resize((imageSize, imageSize)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]))

    sampler = InfiniteSampler(catDataset)
    catLoader = DataLoader(catDataset, batch_size=batchSize, sampler=sampler)

    # Target distances: Euclidean distances between random points, so the
    # matrix is symmetric and satisfies the triangle inequality.
    randomPoints = generate_points(dataSize)
    distancesCpu = scipy.spatial.distance_matrix(randomPoints,
                                                 randomPoints,
                                                 p=2)

    if initWithCats:
        # Initialize from randomly chosen cat photos (with replacement).
        imagePaths = random.choices(glob.glob(os.path.join(imageRootPath,
                                                           '*')),
                                    k=dataSize)
        catImages = []
        for p in imagePaths:
            raw = imageio.imread(p)
            if raw.ndim == 2:
                # Grayscale file: replicate to 3 channels so transpose works.
                raw = np.stack([raw] * 3, axis=-1)
            raw = raw[..., :3]  # drop an alpha channel if present
            # order=1 (bilinear); for 8-bit input the result is float in [0, 1].
            image = skimage.transform.resize(raw, (imageSize, imageSize),
                                             1).transpose(2, 0, 1)
            catImages.append(image)

        imagesInitCpu = np.asarray(catImages)
    else:
        imagesInitCpu = np.clip(
            np.random.normal(0.5, 0.5 / 3,
                             (dataSize, 3, imageSize, imageSize)), 0, 1)

    # The images themselves are the optimized parameters.
    images = torch.tensor(imagesInitCpu,
                          requires_grad=True,
                          dtype=torch.float32,
                          device=gpu)

    # Learned scale between target distances and perceptual distances.
    scale = torch.tensor(1.0,
                         requires_grad=True,
                         dtype=torch.float32,
                         device=gpu)

    lossModel = models.PerceptualLoss(model='net-lin', net='vgg',
                                      use_gpu=True).to(gpu)
    lossBce = torch.nn.BCELoss()

    discriminator = Discriminator(3, 64, 1)
    if discCheckpointPath:
        discriminator.load_state_dict(torch.load(discCheckpointPath))
    else:
        discriminator.init_params()
    discriminator = discriminator.to(gpu)

    def toDiscRange(x):
        # Real batches reach the discriminator in [-1, 1] (see Normalize
        # above) while `images` is clamped to [0, 1]. Map generated images to
        # the same range; otherwise the discriminator can separate real from
        # fake by value range alone.
        return x * 2.0 - 1.0

    optimizerImages = torch.optim.Adam([images, scale],
                                       lr=1e-3,
                                       betas=(0.9, 0.999))
    optimizerDisc = torch.optim.Adam(discriminator.parameters(),
                                     lr=0.0002,
                                     betas=(0.5, 0.999))

    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(nrows=2, ncols=batchSize // 2)

    fig2 = plt.figure()
    ax2 = fig2.add_subplot(1, 1, 1)

    outPath = os.path.join(
        'images',
        datetime.datetime.today().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(outPath)

    catIter = iter(catLoader)
    for batchIndex in range(10000):

        realImageBatch, _ = next(catIter)  # type: Tuple[torch.Tensor, Any]
        realImageBatch = realImageBatch.to(gpu)

        # Random subset (with replacement) of the optimized images.
        # noinspection PyTypeChecker
        randomIndices = np.random.randint(
            0, dataSize, batchSize).tolist()  # type: List[int]
        distanceBatch = torch.tensor(
            distancesCpu[randomIndices, :][:, randomIndices],
            dtype=torch.float32,
            device=gpu)
        imageBatch = images[randomIndices].contiguous()

        # All batchSize x batchSize pairs: repeat() tiles the batch and
        # repeat_interleave() repeats each image, so row-major reshaping
        # yields the predicted distance matrix.
        distPred = lossModel.forward(
            imageBatch.repeat(repeats=(batchSize, 1, 1, 1)).contiguous(),
            imageBatch.repeat_interleave(repeats=batchSize,
                                         dim=0).contiguous(),
            normalize=True)
        distPredMat = distPred.reshape((batchSize, batchSize))

        lossDist = torch.sum((distanceBatch - distPredMat * scale)**2)  # MSE
        discPred = discriminator(toDiscRange(imageBatch))
        lossRealness = lossBce(discPred,
                               torch.ones(imageBatch.shape[0], 1, device=gpu))
        lossImages = lossDist + 100.0 * lossRealness  # todo

        optimizerImages.zero_grad()
        lossImages.backward()
        optimizerImages.step()

        # Discriminator step: real cats -> 1, current (pre-step) images -> 0.
        lossDiscReal = lossBce(
            discriminator(realImageBatch),
            torch.ones(realImageBatch.shape[0], 1, device=gpu))
        lossDiscFake = lossBce(discriminator(toDiscRange(imageBatch.detach())),
                               torch.zeros(imageBatch.shape[0], 1, device=gpu))
        lossDisc = (lossDiscFake + lossDiscReal) / 2

        optimizerDisc.zero_grad()
        lossDisc.backward()
        optimizerDisc.step()

        with torch.no_grad():
            # Clamp in place: rebinding `images` to a new tensor would detach
            # it from optimizerImages (the cause of the earlier "worse
            # training" with images = torch.clamp(...)).
            # todo  We're clamping all the images every batch, can we do clamp only the ones updated?
            images.clamp_(0, 1)

        if batchIndex % 100 == 0:
            msg = 'iter {}, loss images {:.3f}, loss dist {:.3f}, loss real {:.3f}, loss disc {:.3f}, scale: {:.3f}'.format(
                batchIndex, lossImages.item(), lossDist.item(),
                lossRealness.item(), lossDisc.item(), scale.item())
            print(msg)
            imageBatchCpu = imageBatch.detach().cpu().numpy().transpose(0, 2, 3, 1)
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(imageBatchCpu[i])
            fig.suptitle(msg)

            imagesAllCpu = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
            plot_image_scatter(ax2,
                               randomPoints,
                               imagesAllCpu,
                               downscaleRatio=2)

            fig.savefig(
                os.path.join(outPath, 'batch_{}.png'.format(batchIndex)))
            fig2.savefig(
                os.path.join(outPath, 'scatter_{}.png'.format(batchIndex)))
Exemplo n.º 11
0
def _gpu_images_to_numpy(images):
    """Convert an (N, C, H, W) image tensor in [-1, 1] to an (N, H, W, C) array in [0, 1]."""
    imagesNumpy = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
    return (imagesNumpy + 1) / 2


def _lpips_pairwise(lossModel, imagesA, imagesB):
    """Return the matrix of LPIPS distances between two image batches.

    Entry ``[r, c]`` is the distance between ``imagesA[r]`` and ``imagesB[c]``.
    Both batches must be in [-1, 1] (the range of the normalized data and of the
    generator output), so LPIPS is called with ``normalize=False``;
    ``normalize=True`` is meant for [0, 1] inputs and would map these images to [-3, 1].
    """
    numA, numB = imagesA.shape[0], imagesB.shape[0]
    # Enumerate all pairs row-major (A varies slowest, B fastest) so that a plain
    # reshape yields [row = image of A, column = image of B].
    left = imagesA.repeat_interleave(repeats=numB, dim=0).contiguous()
    right = imagesB.repeat(repeats=(numA, 1, 1, 1)).contiguous()
    dist = lossModel.forward(left, right, normalize=False)
    return dist.reshape((numA, numB))


def main():
    """Train a generator that maps author vectors to images for 10000 iterations.

    The generator loss is the sum of squared errors between the in-batch
    ``l2_sqr_dist_matrix`` of the input vectors and ``scale`` times the LPIPS
    distances between the generated images, plus a BCE "realness" term from a
    discriminator that is trained to separate CIFAR-10 images from generated ones.
    Every 100 iterations, sample grids, a UMAP image scatter, distance-matrix
    renderings and generator/discriminator checkpoints are written to
    ``runs/<timestamp>``.
    """
    batchSize = 8
    # imageSize = 32
    imageSize = 64

    # discCheckpointPath = r'E:\projects\visus\PyTorch-GAN\implementations\dcgan\checkpoints\2020_07_10_15_53_34\disc_step4800.pth'
    # discCheckpointPath = r'E:\projects\visus\pytorch-examples\dcgan\out\netD_epoch_24.pth'
    discCheckpointPath = None

    gpu = torch.device('cuda')

    # imageDataset = CatDataset(
    #     imageSubdirPath=r'E:\data\cat-vs-dog\cat',
    #     transform=transforms.Compose(
    #         [
    #             transforms.Resize((imageSize, imageSize)),
    #             transforms.ToTensor(),
    #             transforms.Normalize([0.5], [0.5])
    #         ]
    #     )
    # )

    # Real images are normalized to [-1, 1]; generated images are assumed to use
    # the same range (they are shown via (x + 1) / 2 and compared against the real ones).
    imageDataset = datasets.CIFAR10(root=r'e:\data\images\cifar10', download=True,
                                    transform=transforms.Compose([
                                        transforms.Resize((imageSize, imageSize)),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.5], [0.5]),
                                    ]))

    # For now we normalize the vectors to have norm 1, but don't make sure
    # that the data has certain mean/std.
    pointDataset = AuthorDataset(
        jsonPath=r'E:\out\scripts\metaphor-vis\authors-all.json'
    )

    imageLoader = DataLoader(imageDataset, batch_size=batchSize, sampler=InfiniteSampler(imageDataset))
    pointLoader = DataLoader(pointDataset, batch_size=batchSize, sampler=InfiniteSampler(pointDataset))

    # Learned factor mapping LPIPS distances onto the scale of the vector distances.
    scale = torch.tensor(4.0, requires_grad=True, dtype=torch.float32, device=gpu)

    lossModel = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True).to(gpu)
    bceLoss = torch.nn.BCELoss()

    discriminator = Discriminator(3, 64, 1)
    if discCheckpointPath:
        discriminator.load_state_dict(torch.load(discCheckpointPath))
    else:
        discriminator.init_params()
    discriminator = discriminator.to(gpu)

    generator = Generator(nz=pointDataset[0][0].shape[0], ngf=64)
    generator.init_params()
    generator = generator.to(gpu)

    optimizerScale = torch.optim.Adam([scale], lr=0.001)
    optimizerGen = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerDisc = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    import matplotlib.pyplot as plt
    # Figures are created once and redrawn (after clearing) at every logging step.
    fig, axes = plt.subplots(nrows=2 * 2, ncols=batchSize // 2)
    axesFlat = axes.flatten()
    fig2 = plt.figure()
    ax2 = fig2.add_subplot(1, 1, 1)

    outPath = os.path.join('runs', datetime.datetime.today().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(outPath)

    imageIter = iter(imageLoader)
    pointIter = iter(pointLoader)
    for batchIndex in range(10000):
        imageBatchReal, _ = next(imageIter)
        imageBatchReal = imageBatchReal.to(gpu)

        vectorBatch, _ = next(pointIter)
        vectorBatch = vectorBatch.to(gpu)
        distanceBatch = l2_sqr_dist_matrix(vectorBatch)  # In-batch vector distances.

        imageBatchFake = generator(vectorBatch[:, :, None, None].float())

        # todo It's possible to compute this more efficiently, but would require re-implementing lpips.
        distPredMat = _lpips_pairwise(lossModel, imageBatchFake, imageBatchFake)

        # Sum (not mean) of squared errors between target and scaled predicted distances.
        lossDist = torch.sum((distanceBatch - distPredMat * scale) ** 2)
        discPred = discriminator(imageBatchFake)
        lossRealness = bceLoss(discPred, torch.ones(imageBatchFake.shape[0], device=gpu))
        lossGen = lossDist + 1.0 * lossRealness

        optimizerGen.zero_grad()
        optimizerScale.zero_grad()
        lossGen.backward()
        optimizerGen.step()
        optimizerScale.step()

        # Discriminator step; fakes are detached so the generator is not updated here.
        lossDiscReal = bceLoss(discriminator(imageBatchReal), torch.ones(imageBatchReal.shape[0], device=gpu))
        lossDiscFake = bceLoss(discriminator(imageBatchFake.detach()), torch.zeros(imageBatchFake.shape[0], device=gpu))
        lossDisc = (lossDiscFake + lossDiscReal) / 2

        optimizerDisc.zero_grad()
        lossDisc.backward()
        optimizerDisc.step()

        if batchIndex % 100 != 0:
            continue

        msg = 'iter {}, loss gen {:.3f}, loss dist {:.3f}, loss real {:.3f}, loss disc {:.3f}, scale: {:.3f}'.format(
            batchIndex, lossGen.item(), lossDist.item(), lossRealness.item(), lossDisc.item(), scale.item()
        )
        print(msg)

        # Clear each axes first: imshow on a used axes stacks another image on top.
        imageBatchFakeCpu = _gpu_images_to_numpy(imageBatchFake)
        imageBatchRealCpu = _gpu_images_to_numpy(imageBatchReal)
        for ax, image in zip(axesFlat[:batchSize], imageBatchFakeCpu):
            ax.clear()
            ax.imshow(image)
        for ax, image in zip(axesFlat[batchSize:], imageBatchRealCpu):
            ax.clear()
            ax.imshow(image)
        fig.suptitle(msg)

        with torch.no_grad():
            points = np.asarray([pointDataset[i][0] for i in range(200)], dtype=np.float32)
            scatterImages = _gpu_images_to_numpy(generator(torch.tensor(points[..., None, None], device=gpu)))

            authorVectorsProj = umap.UMAP(n_neighbors=5, random_state=1337).fit_transform(points)
            ax2.clear()
            plot_image_scatter(ax2, authorVectorsProj, scatterImages, downscaleRatio=2)

        fig.savefig(os.path.join(outPath, f'batch_{batchIndex}.png'))
        fig2.savefig(os.path.join(outPath, f'scatter_{batchIndex}.png'))

        with torch.no_grad():
            imageNumber = 48
            points = np.asarray([pointDataset[i][0] for i in range(imageNumber)], dtype=np.float32)
            imagesGpu = generator(torch.tensor(points[..., None, None], device=gpu))

            # Compute LPIPS distances block by block to avoid memory issues.
            # Entry [p, q] is the distance between images p and q.
            bs = 8
            distImages = np.zeros((imagesGpu.shape[0], imagesGpu.shape[0]))
            for startA in range(0, imageNumber, bs):
                for startB in range(0, imageNumber, bs):
                    distImages[startA:startA + bs, startB:startB + bs] = _lpips_pairwise(
                        lossModel, imagesGpu[startA:startA + bs], imagesGpu[startB:startB + bs]
                    ).cpu().numpy()

            # Move to the CPU and append an alpha channel for rendering.
            images = _gpu_images_to_numpy(imagesGpu)
            images = [np.concatenate([im, np.ones(im.shape[:-1] + (1,))], axis=-1) for im in images]

            distPoints = l2_sqr_dist_matrix(torch.tensor(points, dtype=torch.double)).numpy()
            assert np.abs(distPoints - distPoints.T).max() < 1e-5
            distPoints = np.minimum(distPoints, distPoints.T)  # Remove rounding errors, guarantee symmetry.
            config = DistanceMatrixConfig()
            config.dataRange = (0., 4.)
            render_distance_matrix(os.path.join(outPath, f'dist_point_{batchIndex}.png'),
                                   distPoints,
                                   images,
                                   config)

            assert np.abs(distImages - distImages.T).max() < 1e-5
            distImages = np.minimum(distImages, distImages.T)  # Remove rounding errors, guarantee symmetry.
            config = DistanceMatrixConfig()
            config.dataRange = (0., 1.)
            render_distance_matrix(os.path.join(outPath, f'dist_images_{batchIndex}.png'),
                                   distImages,
                                   images,
                                   config)

        torch.save(generator.state_dict(), os.path.join(outPath, 'gen_{}.pth'.format(batchIndex)))
        torch.save(discriminator.state_dict(), os.path.join(outPath, 'disc_{}.pth'.format(batchIndex)))

    plt.close(fig)
    plt.close(fig2)
Exemplo n.º 12
0
def init_weight(model):
    """Initialize a convolution module's weight in place from N(0, 0.02).

    Meant to be passed to ``nn.Module.apply``, which calls it once per submodule.
    Only modules whose class name contains "conv" (case-insensitive, e.g.
    ``Conv2d`` / ``ConvTranspose2d``) and that have a ``weight`` are touched;
    every other module is left unchanged.
    """
    classname = model.__class__.__name__
    # PyTorch conv classes are capitalized ("Conv2d"), so a case-sensitive search
    # for "conv" never matched and no weights were ever initialized.
    if 'conv' in classname.lower() and getattr(model, 'weight', None) is not None:
        torch.nn.init.normal_(model.weight.data, 0, 0.02)


# Initialise the convolution weights of both networks with init_weight.
G.apply(init_weight)
D.apply(init_weight)

# The discriminator always uses binary cross-entropy; the generator uses MSE
# when args.feature_matching is set and BCE otherwise.
criterion_d = torch.nn.BCELoss()
if args.feature_matching:
    criterion_g = torch.nn.MSELoss()
else:
    criterion_g = torch.nn.BCELoss()

# Both optimizers share the same Adam hyper-parameters taken from `args`.
optimizer_g = torch.optim.Adam(
    G.parameters(), lr=args.learning_rate, betas=[args.beta_1, args.beta_2]
)
optimizer_d = torch.optim.Adam(
    D.parameters(), lr=args.learning_rate, betas=[args.beta_1, args.beta_2]
)

if __name__ == "__main__":
    # One-sided label smoothing: "real" targets are filled with args.alpha while
    # "fake" targets stay exactly 0; both are shaped (args.batch_size, 1).
    # NOTE(review): a final batch smaller than args.batch_size would not match
    # these fixed-size targets — confirm the dataloader drops the last batch.
    pos_labels = torch.full((args.batch_size, 1), args.alpha, device=device)
    neg_labels = torch.zeros((args.batch_size, 1), device=device)

    for epoch in range(args.epochs):
        # Fresh loss lists each epoch (presumably filled with per-batch losses).
        losses_d, losses_g = [], []
        for i, data in enumerate(dataloader):
            # Train D with genuine (real) data.
            # NOTE(review): the rest of this loop body appears to be missing from
            # this file; it ends after the reshape below.
            genuine = data[0].to(device)  # Drop label data
            # Reshape to (N, NUM_COLORS, IMAGE_SIZE, IMAGE_SIZE).
            genuine = genuine.reshape(-1, NUM_COLORS, IMAGE_SIZE, IMAGE_SIZE)
Exemplo n.º 13
0
        dped_dir=dped_dir,
        dataset_size=train_size,
        image_size=PATCH_SIZE,
        test=False,
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=nw,
    pin_memory=True,
)

# Optimizers: Adam with identical settings (lr, betas=(0.5, 0.99)) for both networks.
optimizer_G = torch.optim.Adam(
    generator.parameters(), lr=learning_rate, betas=(0.5, 0.99)
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.99)
)

# Legacy tensor-type constructor, picked once depending on whether CUDA is used.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# ----------
#  Training
# ----------

# NOTE(review): num_train_iters bounds the outer loop and each outer step is a
# full pass over `dataloader`, so it counts epochs rather than iterations.
# The loop body appears truncated in this file.
for epoch in range(num_train_iters):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths: (N, 1) columns of 1.0 targets for real images
        # and 0.0 targets for generated ones, where N is this batch's size.
        valid = Tensor(imgs.shape[0], 1).fill_(1.0)
        fake = Tensor(imgs.shape[0], 1).fill_(0.0)
Exemplo n.º 14
0
def train(z_channels,
          c_channels,
          epoch_num,
          batch_size,
          lr=0.0002,
          beta1=0.5,
          model_path='models/dcgan_checkpoint.pth',
          save_every=1):
    """Train the generator/discriminator pair on the data returned by ``load_mnist``.

    Networks, optimizers and the epoch counter are resumed from ``model_path``
    when that file exists; then ``epoch_num`` further epochs are run. Both real
    and generated images get additive Gaussian noise whose std starts at 0.1 and
    shrinks by 10% per epoch. Real images get soft, noisy targets around 0.9.

    Args:
        z_channels: Channels of the latent noise fed to ``Generator``.
        c_channels: Image channels for ``Generator`` / ``Discriminator``.
        epoch_num: Number of epochs to run after any resumed ones.
        batch_size: Batch size passed to ``load_mnist``.
        lr: Adam learning rate for both networks.
        beta1: Adam beta1 for both networks.
        model_path: Checkpoint file that is resumed from and overwritten.
        save_every: Write a sample image and checkpoint every ``save_every``
            epochs (must be positive; default 1 = every epoch).
    """
    import shutil

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        cudnn.benchmark = True
    else:
        print("*****   Warning: Cuda isn't available!  *****")

    loader = load_mnist(batch_size)

    generator = Generator(z_channels, c_channels).to(device)
    discriminator = Discriminator(c_channels).to(device)
    g_optimizer = optim.Adam(generator.parameters(),
                             lr=lr,
                             betas=(beta1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(),
                             lr=lr,
                             betas=(beta1, 0.999))
    start_epoch = 0
    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a GPU be resumed on a CPU.
        checkpoint = torch.load(model_path, map_location=device)
        generator.load_state_dict(checkpoint['g'])
        discriminator.load_state_dict(checkpoint['d'])
        g_optimizer.load_state_dict(checkpoint['g_optim'])
        d_optimizer.load_state_dict(checkpoint['d_optim'])
        start_epoch = checkpoint['epoch'] + 1
    criterion = nn.BCELoss().to(device)

    generator.train()
    discriminator.train()
    std = 0.1  # std of the input noise, decayed after every epoch
    for epoch in range(start_epoch, start_epoch + epoch_num):
        d_loss_sum, g_loss_sum = 0, 0
        num_batches = 0
        fake_image = None
        print('----    epoch: %d    ----' % (epoch, ))
        for real_image, number in loader:
            real_image = real_image.to(device)
            image_noise = torch.randn(real_image.size(), device=device) * std

            # ---- Discriminator: real batch with soft, noisy targets ----
            d_optimizer.zero_grad()
            # Targets ~ N(0.9, 0.1), clamped to [0, 1]. With BCE, a target above 1
            # is no longer minimised at the target: the loss keeps decreasing as
            # the prediction approaches 1, which degenerates both objectives.
            real_label = torch.empty(number.size(), device=device).normal_(0.9, 0.1).clamp_(0.0, 1.0)
            real_image.add_(image_noise)
            out = discriminator(real_image)
            d_real_loss = criterion(out, real_label)
            d_real_loss.backward()

            # ---- Discriminator: generated batch with hard 0 targets ----
            noise_z = torch.randn((number.size(0), z_channels, 1, 1),
                                  device=device)
            fake_image = generator(noise_z)
            fake_label = torch.zeros(number.size(), device=device)
            fake_image = fake_image.add(image_noise)
            out = discriminator(fake_image.detach())
            d_fake_loss = criterion(out, fake_label)
            d_fake_loss.backward()

            d_optimizer.step()

            # ---- Generator: make the discriminator call the fakes real ----
            g_optimizer.zero_grad()
            out = discriminator(fake_image)
            g_loss = criterion(out, real_label)
            g_loss.backward()
            g_optimizer.step()

            d_loss_sum += d_real_loss.item() + d_fake_loss.item()
            g_loss_sum += g_loss.item()
            num_batches += 1

        # max(..., 1) keeps an empty loader from dividing by zero.
        print('d_loss: %f \t\t g_loss: %f' % (d_loss_sum / max(num_batches, 1),
                                              g_loss_sum / max(num_batches, 1)))
        std *= 0.9
        if epoch % save_every == 0:
            checkpoint = {
                'g': generator.state_dict(),
                'd': discriminator.state_dict(),
                'g_optim': g_optimizer.state_dict(),
                'd_optim': d_optimizer.state_dict(),
                'epoch': epoch,
            }
            if fake_image is not None:
                os.makedirs('out', exist_ok=True)
                save_image(fake_image.detach(),
                           'out/fake_samples_epoch_%03d.png' % (epoch, ),
                           normalize=False)
            model_dir = os.path.dirname(model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            torch.save(checkpoint, model_path)
            # Per-epoch copy of the checkpoint. shutil avoids building a shell
            # command from a file path and works where `cp` is unavailable.
            os.makedirs('models', exist_ok=True)
            shutil.copyfile(model_path, os.path.join('models', 'model%d' % (epoch, )))
            print('saved!')
Exemplo n.º 15
0
# initialize discriminator weights
discriminator.apply(weights_init)
print('##### GENERATOR #####')
print(generator)
params = list(generator.parameters())
# Print the shapes of the first (up to) 13 generator parameter tensors. The
# bound is clamped so a generator with fewer parameters does not raise IndexError.
for i in range(min(13, len(params))):
    print(params[i].size())

print('######################')
print('\n##### DISCRIMINATOR #####')
print(discriminator)
print('######################')

# optimizers
optim_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optim_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# loss function
criterion = nn.BCELoss()

losses_g = []  # to store generator loss after each epoch
losses_d = []  # to store discriminator loss after each epoch


# function to train the discriminator network
def train_discriminator(optimizer, data_real, data_fake):
    """Prepare a discriminator update for one batch.

    Builds squeezed "real" and "fake" target vectors sized to ``data_real`` and
    clears the optimizer's gradients.

    NOTE(review): as written, the function stops there. The targets and
    ``data_fake`` are never used and ``None`` is returned, so the loss, backward
    and step appear to be missing (possibly truncated).
    """
    batch_size = data_real.size(0)
    # Target vectors from the label helpers, with singleton dims squeezed away.
    real_targets = label_real(batch_size).squeeze()
    fake_targets = label_fake(batch_size).squeeze()
    optimizer.zero_grad()
Exemplo n.º 16
0
# Restore the discriminator from the loaded checkpoint (random initialisation
# via weights_init is disabled below).
netD.load_state_dict(checkpoint['discriminator'])
#netD.apply(weights_init)
# Print the model.
#print(netD)

# Binary Cross Entropy loss function.
criterion = nn.BCELoss()

# Fixed latent batch, presumably reused to visualise generator progress.
fixed_noise = torch.randn(128, params['nz'], 1, 1, device=device)

# Float labels: torch.full(size, 1) infers an int64 tensor, which BCELoss
# rejects as a target; 1. / 0. make label tensors built from them float.
real_label = 1.
fake_label = 0.

# Optimizer for the discriminator.
optimizerD = optim.Adam(netD.parameters(), lr=params['lr'], betas=(params['beta1'], 0.999))
# Optimizer for the generator.
optimizerG = optim.Adam(netG.parameters(), lr=params['lr'], betas=(params['beta1'], 0.999))

# Stores generated images as training progresses.
img_list = []
# Stores generator losses during training.
G_losses = []
# Stores discriminator losses during training.
D_losses = []

iters = 0

print("Starting Training Loop...")
print("-"*25)
###
Exemplo n.º 17
0
def _chw_to_bgr_uint8(image):
    """Convert one (C, H, W) RGB float image in [-1, 1] to an (H, W, C) BGR uint8 array.

    ``cv2.imwrite`` writes 8-bit BGR data. Passing the raw float network output
    does not produce a viewable image.
    """
    hwc = np.transpose(image, (1, 2, 0))
    # [-1, 1] -> [0, 255]. Assumes the generator output matches the
    # Normalize((0.5,)*3, (0.5,)*3) training data range — TODO confirm.
    hwc = np.clip((hwc + 1.0) * 127.5, 0, 255).round().astype(np.uint8)
    # RGB (torchvision order) -> BGR (OpenCV order).
    return np.ascontiguousarray(hwc[..., ::-1])


def train():
    """Train the DCGAN on the images under ``data_root``.

    Resumes both networks from ``model_save_path`` if it exists. Every second
    epoch (after the first), one image generated from ``fixed_noise`` is written
    to ``log/<epoch>_fake.png``. Training can be stopped with Ctrl-C. The model
    is saved to ``model_save_path`` both on interruption and on normal completion.
    """
    os.makedirs('log', exist_ok=True)

    ds = datasets.ImageFolder(root=data_root,
                              transform=transforms.Compose([
                                  transforms.Resize(image_size),
                                  transforms.CenterCrop(image_size),
                                  transforms.ToTensor(),
                                  # Scale pixel values to [-1, 1].
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                              ]))

    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)

    net_g = Generator(n_latent_vector, n_g_filters).to(device)
    net_g.apply(weight_init)

    net_d = Discriminator(n_d_filters).to(device)
    net_d.apply(weight_init)

    if os.path.exists(model_save_path):
        # map_location allows resuming a GPU checkpoint on another device.
        all_state_dict = torch.load(model_save_path, map_location=device)
        net_d.load_state_dict(all_state_dict['d_state_dict'])
        net_g.load_state_dict(all_state_dict['g_state_dict'])
        print('model restored from {}'.format(model_save_path))

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(1, n_latent_vector, 1, 1, device=device)

    # Float labels: BCELoss needs float targets; with an int fill value torch.full
    # would infer an int64 tensor.
    real_label = 1.
    fake_label = 0.

    optimizer_d = optim.Adam(net_d.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(net_g.parameters(), lr=lr, betas=(0.5, 0.999))

    print('start training...')

    try:
        for epoch in range(epochs):
            for i, data in enumerate(dataloader, 0):
                # (1) Update the discriminator: real batch with real targets.
                net_d.zero_grad()
                real_cpu = data[0].to(device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
                output = net_d(real_cpu).view(-1)
                err_d_real = criterion(output, label)
                err_d_real.backward()
                d_x = output.mean().item()

                # ... and a generated batch with fake targets (detached: no G update here).
                noise = torch.randn(b_size, n_latent_vector, 1, 1, device=device)
                fake = net_g(noise)
                label.fill_(fake_label)
                output = net_d(fake.detach()).view(-1)
                err_d_fake = criterion(output, label)
                err_d_fake.backward()
                d_g_z1 = output.mean().item()
                err_d = err_d_real + err_d_fake
                optimizer_d.step()

                # (2) Update the generator: make D label the fakes as real.
                net_g.zero_grad()
                label.fill_(real_label)
                output = net_d(fake).view(-1)
                err_g = criterion(output, label)
                err_g.backward()
                d_g_z2 = output.mean().item()
                # Must step the *generator* optimizer. Stepping optimizer_d here left
                # G untrained and updated D with the generator's gradients.
                optimizer_g.step()

                if i % 50 == 0:
                    print(f'Epoch: {epoch}, loss_d: {err_d.item()}, loss_g: {err_g.item()}, '
                          f'D(x): {d_x:.4f}, D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}')

            if epoch % 2 == 0 and epoch != 0:
                with torch.no_grad():
                    fake = net_g(fixed_noise).cpu().numpy()
                cv2.imwrite('log/{}_fake.png'.format(epoch), _chw_to_bgr_uint8(fake[0]))
                print('record a fake image to local.')

    except KeyboardInterrupt:
        print('interrupted, try saving the model')

    all_state_dict = {
        'd_state_dict': net_d.state_dict(),
        'g_state_dict': net_g.state_dict(),
    }
    torch.save(all_state_dict, model_save_path)
    print('model saved...')