示例#1
0
def load_model(model_file, hidden_size, upsampling, cuda=False):
    """Restore a GAN training checkpoint from disk.

    Args:
        model_file: path to a checkpoint written by save_checkpoint().
        hidden_size: dimensionality of the generator's latent input.
        upsampling: one of 'transpose', 'nn' or 'bilinear'; selects which
            models module to import Generator/Discriminator from.
        cuda: if True, load checkpoint tensors onto the GPU; otherwise map
            every storage to CPU so GPU checkpoints load on CPU-only hosts.

    Returns:
        Tuple of (total_examples, fixed_noise, gen_losses, disc_losses,
        gen_loss_per_epoch, disc_loss_per_epoch, epoch, gen, disc).

    Raises:
        ValueError: if `upsampling` is not a recognised mode.
    """
    # Validate the mode up front: previously an unknown value fell through
    # and caused a confusing NameError when Generator was referenced below.
    if upsampling == 'transpose':
        from models.model import Generator, Discriminator
    elif upsampling == 'nn':
        from models.model_nn import Generator, Discriminator
    elif upsampling == 'bilinear':
        from models.model_bilinear import Generator, Discriminator
    else:
        raise ValueError(
            "unknown upsampling mode: {!r} (expected 'transpose', 'nn' or "
            "'bilinear')".format(upsampling))

    if cuda:
        from_before = torch.load(model_file)
    else:
        from_before = torch.load(model_file,
                                 map_location=lambda storage, loc: storage)
    total_examples = from_before['total_examples']
    gen_losses = from_before['gen_losses']
    disc_losses = from_before['disc_losses']
    gen_loss_per_epoch = from_before['gen_loss_per_epoch']
    disc_loss_per_epoch = from_before['disc_loss_per_epoch']
    gen_state_dict = from_before['gen_state_dict']
    disc_state_dict = from_before['disc_state_dict']
    fixed_noise = from_before['fixed_noise']
    epoch = from_before['epoch']

    gen = Generator(hidden_dim=hidden_size,
                    dropout=0.4)  # TODO: save dropout in checkpoint
    disc = Discriminator(leaky=0.2, dropout=0.4)  # TODO: same here
    disc.load_state_dict(disc_state_dict)
    gen.load_state_dict(gen_state_dict)
    return total_examples, fixed_noise, gen_losses, disc_losses, \
           gen_loss_per_epoch, disc_loss_per_epoch, epoch, gen, disc
示例#2
0
 def test_shape_g(self):
     """Generator output shape must match the (N, C, H, W) fixture values."""
     gen = Generator(self.noise_dim, self.in_channels, 8)
     expected = (self.N, self.in_channels, self.H, self.W)
     self.assertEqual(gen(self.z).shape, expected)
     print(f"{gen(self.z).shape} == {expected}")
示例#3
0
def train():
    """Entry point: build a PGGAN generator/discriminator from config and train.

    Parses CLI args, loads the config file, constructs G and D sized for
    cfg.train.target_size, builds the dataset and noise source, and hands
    everything to PGGAN.train().
    """
    args = parse_args()

    cfg = Config.from_file(args.config)

    # Use a sigmoid on D's last layer only for losses that expect outputs in
    # [0, 1] ('ls', 'gan'); other loss types want raw scores.
    cfg.models.discriminator.sigmoid_at_end = cfg.train.loss_type in [
        'ls', 'gan'
    ]

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    G = Generator(model_cfg=cfg.models.generator,
                  target_size=cfg.train.target_size)
    D = Discriminator(model_cfg=cfg.models.discriminator,
                      target_size=cfg.train.target_size)

    dataset = FaceDataset(cfg.train.dataset)
    assert len(dataset) > 0
    print(f'train dataset contains {len(dataset)} images.')

    # BUG FIX: the original probed the misspelled attribute 'z_zlipping' with
    # hasattr, so a configured z_clipping value was silently ignored and clip
    # was always None. getattr with a default does the lookup correctly.
    clip = getattr(cfg.models.generator, 'z_clipping', None)
    z_generator = RandomNoiseGenerator(cfg.models.generator.z_dim,
                                       'gaussian',
                                       clip=clip)
    pggan = PGGAN(G, D, dataset, z_generator, args.gpu, cfg, args.resume)
    pggan.train()
示例#4
0
    def model_init(self):
        """Build D/G networks, their optimizers and LR schedulers, and make
        sure the checkpoint directory exists (optionally resuming from it)."""
        # Networks, moved to GPU when self.cuda is set.
        self.D = Discriminator(self.model_type, self.image_size,
                               self.hidden_dim, self.n_filter, self.n_repeat)
        self.G = Generator(self.model_type, self.image_size, self.hidden_dim,
                           self.n_filter, self.n_repeat)
        self.D = cuda(self.D, self.cuda)
        self.G = cuda(self.G, self.cuda)

        # DCGAN-style normal weight initialisation (same for both nets).
        for net in (self.D, self.G):
            net.weight_init(mean=0.0, std=0.02)

        adam_betas = (0.5, 0.999)
        self.D_optim = optim.Adam(self.D.parameters(),
                                  lr=self.D_lr,
                                  betas=adam_betas)
        self.G_optim = optim.Adam(self.G.parameters(),
                                  lr=self.G_lr,
                                  betas=adam_betas)

        # Halve the learning rate every scheduler step.
        self.D_optim_scheduler = lr_scheduler.StepLR(self.D_optim,
                                                     step_size=1,
                                                     gamma=0.5)
        self.G_optim_scheduler = lr_scheduler.StepLR(self.G_optim,
                                                     step_size=1,
                                                     gamma=0.5)

        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)

        if self.load_ckpt:
            self.load_checkpoint()
示例#5
0
def load_model(opt):
    """Construct and return the (generator, discriminator) pair from config."""
    print('initializing model ...')

    print("LOADING GENERATOR MODEL")
    g_net = Generator(100, gpu=opt.SYSTEM.USE_GPU)

    print("LOADING DISCRIMINATOR MODEL")
    d_net = Discriminator(output_dim=opt.TRAIN.TOTAL_FEATURES,
                          gpu=opt.SYSTEM.USE_GPU)

    return g_net, d_net
示例#6
0
    def test_generator_shape(self):
        """Both generator residual variants must produce images of shape
        (1, 3, img_size, img_size) and agree with each other."""
        gen_sle = Generator(img_size=self.img_size,
                            in_channels=self.in_channels,
                            z_dim=self.z_dim,
                            res_type=self.res_type[0])
        gen_gc = Generator(img_size=self.img_size,
                           in_channels=self.in_channels,
                           z_dim=self.z_dim,
                           res_type=self.res_type[1])

        # Generalized: draw the latent with the fixture's z_dim instead of the
        # hard-coded 256, so the test stays valid if z_dim changes.
        noise = torch.randn([1, self.z_dim, 1, 1])

        z_gen_sle = gen_sle(noise)
        z_gen_gc = gen_gc(noise)

        real_shape = (1, 3, self.img_size, self.img_size)

        self.assertEqual(z_gen_sle.shape, z_gen_gc.shape)
        self.assertEqual(z_gen_sle.shape, real_shape)
        self.assertEqual(z_gen_gc.shape, real_shape)
示例#7
0
    def test_shape(self):
        """At every PGGAN resolution, G's output has the right shape and
        feeds the critic, which must emit a single score."""
        gen = Generator(self.Z_DIM, self.IN_CHANNELS, img_channels=3)
        critic = Critic(self.IN_CHANNELS, img_channels=3)

        for img_size in [2 ** p for p in range(2, 11)]:  # 4 .. 1024
            num_steps = int(log2(img_size / 4))
            latent = torch.randn((1, self.Z_DIM, 1, 1))
            fake = gen(latent, 0.5, steps=num_steps)
            self.assertEqual(fake.shape, (1, 3, img_size, img_size))
            score = critic(fake, alpha=0.5, steps=num_steps)
            self.assertEqual(score.shape, (1, 1))
            print(f"Done! Img_size:{img_size}")
示例#8
0
    def models_initialize(self):
        """Instantiate generator and discriminator from self.args.MODEL and
        move both to the configured device."""
        print("-- Preparing Models --")

        model_cfg = self.args.MODEL
        self.generator = Generator(ngf=model_cfg.ngf,
                                   input_nc=model_cfg.input_nc,
                                   output_nc=model_cfg.output_nc)
        # D sees the input/output channels concatenated (conditional GAN).
        self.discriminator = NLayerDiscriminator(
            ndf=model_cfg.ndf,
            input_nc=model_cfg.input_nc + model_cfg.output_nc,
            n_layers=model_cfg.n_layers)

        for net in (self.generator, self.discriminator):
            net.to(self.device)

        print("-- Models DONE --")
示例#9
0
    def test_fid_with_images(self):
        """Smoke-test the FID pipeline end to end on CPU and print the score."""
        gen = Generator(1024)
        fid_model = InceptionV3FID(torch.device('cpu'))

        real_dataset = ImgFolderDataset('', fid=True)
        real_dataloader = get_sample_dataloader(real_dataset,
                                                num_samples=4,
                                                batch_size=2)

        # One generated image per real batch.
        noise = torch.randn([len(real_dataloader), 256, 1, 1])
        fake_images = [gen(z.unsqueeze(0)) for z in noise]

        fake_dataloader = DataLoader(FIDNoiseDataset(fake_images),
                                     batch_size=2)

        print(fid_model.get_fid_score(real_dataloader, fake_dataloader))
    def __init__(self, args, device):
        """Set up the generator, the attribute classifier, the optimizer and
        TensorBoard logging for attribute-mapper training."""
        self.args = args
        self.device = device

        # Generator (StyleGAN2-style constructor signature).
        self.generator = Generator(
            args.size,
            args.latent,
            args.n_mlp,
            channel_multiplier=args.channel_multiplier).to(device)

        if args.generator_ckpt is not None:
            print("load generator:", args.generator_ckpt)
            gen_ckpt = torch.load(
                args.generator_ckpt
            )  # map_location=lambda storage, loc: storage)
            # strict=False tolerates keys missing from the checkpoint.
            self.generator.load_state_dict(
                gen_ckpt['g_ema'],
                strict=False)  # TODO: Maybe need to load g

        # Attribute classifier: single-logit ResNet-50.
        backbone = resnet50(pretrained=True, num_classes=1)
        self.attribute_classifier = backbone.to(device)

        if args.attribute_classifier_ckpt is not None:
            print("load attribute clasifier:", args.attribute_classifier_ckpt)
            cls_ckpt = torch.load(args.attribute_classifier_ckpt)
            self.attribute_classifier.load_state_dict(cls_ckpt['state_dict'])

        # Only the generator's attribute mapper is optimized.
        self.optimizer = torch.optim.AdamW(
            self.generator.attribute_mapper.parameters(), lr=args.lr)

        self.classifier_criterion = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

        self.sample_z = None  # Sample for validation
        self.mean_latent = None  # Mean for truncation

        cudnn.benchmark = True
        self.writer = SummaryWriter(log_dir=str(args.log_dir))
示例#11
0
    
    # GIF interpolation needs at least two latent endpoints.
    if args.gif and args.num_samples < 2:
        raise ValueError('for GIF num_samples must be greater than 1')

    out_path = args.out_path if args.out_path else 'DCGAN-Anime-Faces'

    # BUG FIX: fall back to CPU when CUDA was requested but is unavailable;
    # the original left `device` unbound in that case and crashed later.
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # NOTE(review): generator dims (128, 3, 64) are hard-coded; presumably
    # they must match the checkpoint being loaded -- confirm against the
    # training configuration.
    gen = Generator(128, 3, 64)
    load_gen(gen, args.path_ckpt, device)
    gen.eval()

    if args.grid:
        noise = get_random_noise(args.num_samples, args.z_size, device)
        print("==> Generate IMAGE GRID...")
        output = gen(noise)
        show_batch(output, out_path, num_samples=args.num_samples, figsize=(args.img_size, args.img_size))
    elif args.gif:
        noise = get_random_noise(args.num_samples, args.z_size, device)
        print("==> Generate GIF...")
        images = latent_space_interpolation_sequence(noise, step_interpolation=args.steps)
        output = gen(images)
        if args.resize and isinstance(args.resize, int):
            print(f"==> Resize images to {args.resize}px")
示例#12
0
def main(args):
    """Train a conditional wireframe/image GAN with TensorBoard logging.

    Builds train/val data loaders, a Generator (which also computes its own
    weighted losses) and an NLayerDiscriminator, optionally wraps them in
    DataParallelWithCallback for multi-GPU runs, then alternates D and G
    updates each iteration. Per-epoch it checkpoints the single-GPU modules
    and runs one validation batch whose perceptual loss drives both
    ReduceLROnPlateau schedulers.

    Args:
        args: parsed CLI namespace; must carry gpu ids, paths, batch sizes,
            loss multipliers, optimizer/epoch settings and logging frequencies.
    """
    # set which gpu(s) to use; PCI_BUS_ID ordering must be set first
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # args.gpu is a comma-separated id string, e.g. "0,1" -> 2 GPUs.
    num_gpus = (len(args.gpu) + 1) // 2

    # create model directories
    checkpath(args.modelG_path)
    checkpath(args.modelD_path)

    # tensorboard writer
    checkpath(args.log_path)
    writer = SummaryWriter(args.log_path)

    # load data
    data_loader, num_train = get_loader(args,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        training=True)
    data_loader_val, num_test = get_loader(args,
                                           batch_size=args.val_bs,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           training=False)
    print('Finished data loading')
    print("The length of the train set is: {}".format(num_train))
    print("The length of the test set is: {}".format(num_test))

    colorguide = True
    if args.nocolor:
        colorguide = False

    # loss multipliers, in the order the Generator expects them
    lambdas = [
        args.lambda_imgl1, args.lambda_wfl1, args.lambda_ssim,
        args.lambda_color
    ]
    lambda_perceptual = args.lambda_perceptual

    # Generator (computes its weighted losses internally during forward)
    netG = Generator(lambdas=lambdas,
                     colorguide=colorguide,
                     input_nc=1,
                     output_nc=1)

    if num_gpus > 1:
        # multi-gpu training with synchronized batch normalization;
        # make sure enough GPUs are available
        assert (torch.cuda.device_count() >= num_gpus)
        # CUDA_VISIBLE_DEVICES is already set, so ids 0..num_gpus-1 are valid
        netG = DataParallelWithCallback(
            netG, device_ids=[i for i in range(num_gpus)])
        # keep a handle to the unwrapped module for optim/state_dict access
        netG_single = netG.module
    else:
        # single gpu training
        netG_single = netG

    # Discriminator over the concatenated (image, wireframe) pair
    netD = NLayerDiscriminator(input_nc=4, n_layers=4)
    if num_gpus > 1:
        netD = DataParallelWithCallback(
            netD, device_ids=[i for i in range(num_gpus)])
        netD_single = netD.module
    else:
        netD_single = netD

    # print(netG_single)
    # print(netD_single)

    if args.pretrained and args.netG_path != '' and args.netD_path != '':
        netG_single.load_state_dict(torch.load(args.netG_path))
        netD_single.load_state_dict(torch.load(args.netD_path))

    # Right now we only support gpu training
    if torch.cuda.is_available():
        netG = netG.cuda()
        netD = netD.cuda()

    # perceptual loss; kept outside G's forward for better multi-gpu training
    Ploss = PNet()
    if num_gpus > 1:
        Ploss = DataParallelWithCallback(
            Ploss, device_ids=[i for i in range(num_gpus)])

    if torch.cuda.is_available():
        Ploss = Ploss.cuda()

    # setup optimizers: Adam on the unwrapped modules, plateau schedulers
    lr = args.learning_rate
    optimizerD = optim.Adam(netD_single.parameters(),
                            lr=lr,
                            betas=(args.beta1, 0.999))
    schedulerD = ReduceLROnPlateau(optimizerD,
                                   factor=0.7,
                                   patience=10,
                                   mode='min',
                                   min_lr=1e-06)
    optimizerG = optim.Adam(netG_single.parameters(),
                            lr=lr,
                            betas=(args.beta1, 0.999))
    schedulerG = ReduceLROnPlateau(optimizerG,
                                   factor=0.7,
                                   patience=10,
                                   mode='min',
                                   min_lr=1e-06)

    for epoch in range(args.num_epochs):
        # switch to train mode
        netG.train()
        netD.train()

        for i, (img_real, wf_real, color_real) in enumerate(data_loader, 0):
            img_real = img_real.cuda()
            wf_real = wf_real.cuda()
            color_real = color_real.cuda()

            # Update D network; freeze parameters in G to save memory
            for p in netG_single.parameters():
                p.requires_grad = False
            for p in netD_single.parameters():
                p.requires_grad = True

            # if using TTUR, D can be trained multiple steps per G step
            for _ in range(args.D_steps):
                optimizerD.zero_grad()

                # train with real pairs (0.5 averages the real/fake halves)
                real_AB = torch.cat((img_real, wf_real), 1)
                errD_real = 0.5 * netD(trainG=False,
                                       trainReal=True,
                                       real_AB=real_AB,
                                       fake_AB=None).sum()
                errD_real.backward()

                # train with fake pairs produced by G in inference mode
                img_fake, wf_fake, _, _, _, _, _ = netG(trainG=False,
                                                        img_real=None,
                                                        wf_real=wf_real,
                                                        color_real=color_real)
                fake_AB = torch.cat((img_fake, wf_fake), 1)
                errD_fake = 0.5 * netD(trainG=False,
                                       trainReal=False,
                                       real_AB=None,
                                       fake_AB=fake_AB).sum()
                errD_fake.backward()

                errD = errD_real + errD_fake
                optimizerD.step()
                # free intermediates eagerly to reduce peak GPU memory
                del img_fake, wf_fake, fake_AB, real_AB, errD_real, errD_fake

            iterations_before_epoch = epoch * len(data_loader)
            writer.add_scalar('D Loss', errD.item(),
                              iterations_before_epoch + i)
            del errD

            # Update G network; freeze parameters in D to save memory
            for p in netG.parameters():
                p.requires_grad = True
            for p in netD.parameters():
                p.requires_grad = False

            optimizerG.zero_grad()

            # G's forward also returns its internally-weighted loss terms
            img_fake, wf_fake, lossG, wf_ssim, img_l1, color_l1, wf_l1 = netG(
                trainG=True,
                img_real=img_real,
                wf_real=wf_real,
                color_real=color_real)
            ploss = Ploss(img_fake, img_real.detach()).sum()
            fake_AB = torch.cat((img_fake, wf_fake), 1)
            # adversarial term for G (D is frozen above)
            lossD = netD(trainG=True,
                         trainReal=False,
                         real_AB=None,
                         fake_AB=fake_AB).sum()
            errG = (lossG.sum() + lambda_perceptual * ploss + lossD)
            errG.backward()
            optimizerG.step()

            del color_real, fake_AB, lossG, errG

            # NOTE(review): num_gpus is added to the SSIM term everywhere it
            # is printed/logged below — presumably an offset so the summed
            # per-GPU (ssim - 1)-style values read as a positive loss;
            # confirm against the MSSSIM implementation inside G.
            if args.nocolor:
                print(
                    'Epoch: [{}/{}] Iter: [{}/{}] PercLoss : {:.4f} ImageL1 : {:.6f} WfL1 : {:.6f} WfSSIM : {:.6f}'
                    .format(epoch, args.num_epochs, i, len(data_loader),
                            ploss.item(),
                            img_l1.sum().item(),
                            wf_l1.sum().item(),
                            num_gpus + wf_ssim.sum().item()))
            else:
                print(
                    'Epoch: [{}/{}] Iter: [{}/{}] PercLoss : {:.4f} ImageL1 : {:.6f} WfL1 : {:.6f} WfSSIM : {:.6f} ColorL1 : {:.6f}'
                    .format(epoch, args.num_epochs, i, len(data_loader),
                            ploss.item(),
                            img_l1.sum().item(),
                            wf_l1.sum().item(),
                            num_gpus + wf_ssim.sum().item(),
                            color_l1.sum().item()))
                writer.add_scalar('Color Loss',
                                  color_l1.sum().item(),
                                  iterations_before_epoch + i)

            # tensorboard log
            writer.add_scalar('G Loss', lossD.item(),
                              iterations_before_epoch + i)
            writer.add_scalar('Image L1 Loss',
                              img_l1.sum().item(), iterations_before_epoch + i)
            writer.add_scalar('Wireframe MSSSIM Loss',
                              num_gpus + wf_ssim.sum().item(),
                              iterations_before_epoch + i)
            writer.add_scalar('Wireframe L1',
                              wf_l1.sum().item(), iterations_before_epoch + i)
            writer.add_scalar('Image Perceptual Loss', ploss.item(),
                              iterations_before_epoch + i)

            del wf_ssim, ploss, img_l1, color_l1, wf_l1, lossD

            with torch.no_grad():
                # show generated training images in tensorboard
                if i % args.val_freq == 0:
                    real_img = vutils.make_grid(
                        img_real.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Real Image', real_img,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    real_wf = vutils.make_grid(
                        wf_real.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Real Wireframe', real_wf,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    fake_img = vutils.make_grid(
                        img_fake.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Fake Image', fake_img,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    fake_wf = vutils.make_grid(
                        wf_fake.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Fake Wireframe', fake_wf,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    del real_img, real_wf, fake_img, fake_wf

            del img_real, wf_real, img_fake, wf_fake

        # do checkpointing (unwrapped modules so state_dict keys are stable)
        if epoch % args.save_freq == 0 and epoch > 0:
            torch.save(netG_single.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(args.modelG_path, epoch))
            torch.save(netD_single.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(args.modelD_path, epoch))

        # validation on a single (large) batch rather than the whole test set
        with torch.no_grad():
            netG_single.eval()
            # since we use a relatively large validation batch size, we don't
            # iterate the whole test set — one batch is enough for the score
            (img_real, wf_real, color_real) = next(iter(data_loader_val))
            img_real = img_real.cuda()
            wf_real = wf_real.cuda()
            color_real = color_real.cuda()

            img_fake, wf_fake, _, _, _, _, _ = netG_single(
                trainG=False,
                img_real=None,
                wf_real=wf_real,
                color_real=color_real)

            # update lr based on the validation perceptual loss
            val_score = Ploss(img_fake.detach(), img_real.detach()).sum()
            schedulerG.step(val_score)
            schedulerD.step(val_score)
            print('Current lr: {:.6f}'.format(
                optimizerG.param_groups[0]['lr']))

            real_img = vutils.make_grid(img_real.detach()[:args.val_size],
                                        normalize=True,
                                        scale_each=True)
            writer.add_image('Test: Real Image', real_img, epoch)
            real_wf = vutils.make_grid(wf_real.detach()[:args.val_size],
                                       normalize=True,
                                       scale_each=True)
            writer.add_image('Test: Real Wireframe', real_wf, epoch)
            fake_img = vutils.make_grid(img_fake.detach()[:args.val_size],
                                        normalize=True,
                                        scale_each=True)
            writer.add_image('Test: Fake Image', fake_img, epoch)
            fake_wf = vutils.make_grid(wf_fake.detach()[:args.val_size],
                                       normalize=True,
                                       scale_each=True)
            writer.add_image('Test: Fake Wireframe', fake_wf, epoch)

            # back to train mode for the next epoch
            netG_single.train()

            del img_real, real_img, wf_real, real_wf, img_fake, fake_img, wf_fake, fake_wf

    # close tb writer
    writer.close()
示例#13
0
def main(train_set, learning_rate, n_epochs, beta_0, beta_1, batch_size,
         num_workers, hidden_size, model_file, cuda, display_result_every,
         checkpoint_interval, seed, label_smoothing, grad_clip, dropout,
         upsampling):
    """Train a DCGAN on an image folder, with checkpointing and curve plots.

    Resumes from `model_file` when given (falling back to a fresh model on
    any load failure), then alternates discriminator/generator BCE updates
    with optional label smoothing and gradient clipping. Checkpoints, sample
    images and learning curves are written under results/ every
    `checkpoint_interval` examples; Ctrl-C saves a final checkpoint.

    Fixes vs. the previous revision:
      * CPU branches sampled `torch.rand` (uniform) for fixed_noise and the
        generator step's z while the CUDA branches (and the D step) used
        `torch.randn` — the noise prior is now standard normal everywhere.
      * the bare `except:` around model loading no longer swallows
        KeyboardInterrupt/SystemExit.
    """
    #  normalize data to [-1, 1] to match a tanh generator output range
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    train_dataset = datasets.ImageFolder(root=os.path.join(
        os.getcwd(), train_set),
                                         transform=data_transform)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  drop_last=True)

    # initialize model: try to resume, fall back to a fresh model on failure
    if model_file:
        try:
            total_examples, fixed_noise, gen_losses, disc_losses, gen_loss_per_epoch, \
            disc_loss_per_epoch, prev_epoch, gen, disc = load_model(model_file, hidden_size, upsampling, cuda)
            print('model loaded successfully!')
        except Exception:  # was a bare except: don't swallow KeyboardInterrupt
            print('could not load model! creating new model...')
            model_file = None

    if not model_file:
        print('creating new model...')
        if upsampling == 'transpose':
            from models.model import Generator, Discriminator
        elif upsampling == 'nn':
            from models.model_nn import Generator, Discriminator
        elif upsampling == 'bilinear':
            from models.model_bilinear import Generator, Discriminator

        gen = Generator(hidden_dim=hidden_size, leaky=0.2, dropout=dropout)
        disc = Discriminator(leaky=0.2, dropout=dropout)

        # DCGAN-style normal weight initialisation
        gen.weight_init(mean=0, std=0.02)
        disc.weight_init(mean=0, std=0.02)

        total_examples = 0
        disc_losses = []
        gen_losses = []
        disc_loss_per_epoch = []
        gen_loss_per_epoch = []
        prev_epoch = 0

        #  fixed latent batch used to render comparable samples over training;
        #  drawn from the same standard-normal prior on CPU and GPU
        if cuda:
            fixed_noise = Variable(torch.randn(9, hidden_size).cuda())
        else:
            fixed_noise = Variable(torch.randn(9, hidden_size))

    if cuda:
        gen.cuda()
        disc.cuda()

    # Binary Cross Entropy loss
    BCE_loss = nn.BCELoss()

    # Adam optimizer
    gen_optimizer = optim.Adam(gen.parameters(),
                               lr=learning_rate,
                               betas=(beta_0, beta_1),
                               eps=1e-8)
    disc_optimizer = optim.Adam(disc.parameters(),
                                lr=learning_rate,
                                betas=(beta_0, beta_1),
                                eps=1e-8)

    # results save folders
    gen_images_dir = 'results/generated_images'
    train_summaries_dir = 'results/training_summaries'
    checkpoint_dir = 'results/checkpoints'
    if not os.path.isdir('results'):
        os.mkdir('results')
    if not os.path.isdir(gen_images_dir):
        os.mkdir(gen_images_dir)
    if not os.path.isdir(train_summaries_dir):
        os.mkdir(train_summaries_dir)
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    np.random.seed(
        seed
    )  # reset training seed to ensure that batches remain the same between runs!

    try:
        for epoch in range(prev_epoch, n_epochs):
            disc_losses_epoch = []
            gen_losses_epoch = []
            for idx, (true_batch, _) in enumerate(train_dataloader):
                disc.zero_grad()

                #  one-sided label smoothing: hack 6 of
                #  https://github.com/soumith/ganhacks
                if label_smoothing:
                    true_target = torch.FloatTensor(batch_size).uniform_(
                        0.7, 1.2)
                else:
                    true_target = torch.ones(batch_size)

                #  sample minibatch of examples from the data distribution
                if cuda:
                    true_batch = Variable(true_batch.cuda())
                    true_target = Variable(true_target.cuda())
                else:
                    true_batch = Variable(true_batch)
                    true_target = Variable(true_target)

                #  train discriminator on true data
                true_disc_result = disc.forward(true_batch)
                disc_train_loss_true = BCE_loss(true_disc_result.squeeze(),
                                                true_target)
                disc_train_loss_true.backward()
                torch.nn.utils.clip_grad_norm(disc.parameters(), grad_clip)

                #  smoothed fake labels in [0, 0.3] when smoothing is on
                if label_smoothing:
                    fake_target = torch.FloatTensor(batch_size).uniform_(
                        0, 0.3)
                else:
                    fake_target = torch.zeros(batch_size)

                if cuda:
                    z = Variable(torch.randn(batch_size, hidden_size).cuda())
                    fake_target = Variable(fake_target.cuda())
                else:
                    z = Variable(torch.randn(batch_size, hidden_size))
                    fake_target = Variable(fake_target)

                #  train discriminator on fake data
                fake_batch = gen.forward(z.view(-1, hidden_size, 1, 1))
                fake_disc_result = disc.forward(fake_batch.detach(
                ))  # detach so gradients not computed for generator
                disc_train_loss_false = BCE_loss(fake_disc_result.squeeze(),
                                                 fake_target)
                disc_train_loss_false.backward()
                torch.nn.utils.clip_grad_norm(disc.parameters(), grad_clip)
                disc_optimizer.step()

                #  compute performance statistics
                disc_train_loss = disc_train_loss_true + disc_train_loss_false
                disc_losses_epoch.append(disc_train_loss.data[0])

                disc_fake_accuracy = 1 - fake_disc_result.mean().data[0]
                disc_true_accuracy = true_disc_result.mean().data[0]

                #  fresh "real" labels for the generator update
                if label_smoothing:
                    true_target = torch.FloatTensor(batch_size).uniform_(
                        0.7, 1.2)
                else:
                    true_target = torch.ones(batch_size)

                if cuda:
                    z = Variable(torch.randn(batch_size, hidden_size).cuda())
                    true_target = Variable(true_target.cuda())
                else:
                    # was torch.rand: CPU path now matches the normal prior
                    z = Variable(torch.randn(batch_size, hidden_size))
                    true_target = Variable(true_target)

                # train generator to make D label fakes as real
                gen.zero_grad()
                fake_batch = gen.forward(z.view(-1, hidden_size, 1, 1))
                disc_result = disc.forward(fake_batch)
                gen_train_loss = BCE_loss(disc_result.squeeze(), true_target)

                gen_train_loss.backward()
                torch.nn.utils.clip_grad_norm(gen.parameters(), grad_clip)
                gen_optimizer.step()
                gen_losses_epoch.append(gen_train_loss.data[0])

                if (total_examples != 0) and (total_examples %
                                              display_result_every == 0):
                    print(
                        'epoch {}: step {}/{} disc true acc: {:.4f} disc fake acc: {:.4f} '
                        'disc loss: {:.4f}, gen loss: {:.4f}'.format(
                            epoch + 1, idx + 1, len(train_dataloader),
                            disc_true_accuracy, disc_fake_accuracy,
                            disc_train_loss.data[0], gen_train_loss.data[0]))

                # Checkpoint model
                total_examples += batch_size
                if (total_examples != 0) and (total_examples %
                                              checkpoint_interval == 0):

                    # NOTE(review): extending here AND at epoch end appears to
                    # duplicate this epoch's losses in the saved curves when a
                    # checkpoint fires mid-epoch — confirm before changing.
                    disc_losses.extend(disc_losses_epoch)
                    gen_losses.extend(gen_losses_epoch)
                    save_checkpoint(total_examples=total_examples,
                                    fixed_noise=fixed_noise,
                                    disc=disc,
                                    gen=gen,
                                    gen_losses=gen_losses,
                                    disc_losses=disc_losses,
                                    disc_loss_per_epoch=disc_loss_per_epoch,
                                    gen_loss_per_epoch=gen_loss_per_epoch,
                                    epoch=epoch,
                                    directory=checkpoint_dir)
                    print("Checkpoint saved!")

                    #  sample images for inspection
                    save_image_sample(batch=gen.forward(
                        fixed_noise.view(-1, hidden_size, 1, 1)),
                                      cuda=cuda,
                                      total_examples=total_examples,
                                      directory=gen_images_dir)
                    print("Saved images!")

                    # save learning curves for inspection
                    save_learning_curve(gen_losses=gen_losses,
                                        disc_losses=disc_losses,
                                        total_examples=total_examples,
                                        directory=train_summaries_dir)
                    print("Saved learning curves!")

            disc_loss_per_epoch.append(np.average(disc_losses_epoch))
            gen_loss_per_epoch.append(np.average(gen_losses_epoch))

            # Save epoch learning curve
            save_learning_curve_epoch(gen_losses=gen_loss_per_epoch,
                                      disc_losses=disc_loss_per_epoch,
                                      total_epochs=epoch + 1,
                                      directory=train_summaries_dir)
            print("Saved learning curves!")

            print('epoch {}/{} disc loss: {:.4f}, gen loss: {:.4f}'.format(
                epoch + 1, n_epochs,
                np.array(disc_losses_epoch).mean(),
                np.array(gen_losses_epoch).mean()))

            disc_losses.extend(disc_losses_epoch)
            gen_losses.extend(gen_losses_epoch)

    except KeyboardInterrupt:
        print("Saving before quit...")
        save_checkpoint(total_examples=total_examples,
                        fixed_noise=fixed_noise,
                        disc=disc,
                        gen=gen,
                        disc_loss_per_epoch=disc_loss_per_epoch,
                        gen_loss_per_epoch=gen_loss_per_epoch,
                        gen_losses=gen_losses,
                        disc_losses=disc_losses,
                        epoch=epoch,
                        directory=checkpoint_dir)
        print("Checkpoint saved!")

        # sample images for inspection
        save_image_sample(batch=gen.forward(
            fixed_noise.view(-1, hidden_size, 1, 1)),
                          cuda=cuda,
                          total_examples=total_examples,
                          directory=gen_images_dir)
        print("Saved images!")

        # save learning curves for inspection
        save_learning_curve(gen_losses=gen_losses,
                            disc_losses=disc_losses,
                            total_examples=total_examples,
                            directory=train_summaries_dir)
        print("Saved learning curves!")
Example #14
0
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument("ckpt",
                        metavar="CHECKPOINT",
                        help="path to the model checkpoints")

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()

    percept = lpips.PerceptualLoss(model="net-lin",
                                   net="vgg",
                                   use_gpu=device.startswith("cuda"))

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    batch_sizes = [args.batch] * n_batch + [resid]

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
                        help='rampup_kimg.')
    parser.add_argument('--rampdown_kimg',
                        default=10000,
                        type=float,
                        help='rampdown_kimg.')
    # TODO: support conditional inputs

    args = parser.parse_args()
    opts = {k: v for k, v in args._get_kwargs()}

    latent_size = 512
    sigmoid_at_end = args.gan in ['lsgan', 'gan']

    G = Generator(num_channels=3,
                  latent_size=latent_size,
                  resolution=args.target_resol,
                  fmap_max=latent_size,
                  fmap_base=8192,
                  tanh_at_end=False)
    D = Discriminator(num_channels=3,
                      resolution=args.target_resol,
                      fmap_max=latent_size,
                      fmap_base=8192,
                      sigmoid_at_end=sigmoid_at_end)
    print(G)
    print(D)
    data = CelebA()
    noise = RandomNoiseGenerator(latent_size, 'gaussian')
    pggan = PGGAN(G, D, data, noise, opts)
    pggan.train()
    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      n_mean_latent)**0.5

    percept = lpips.PerceptualLoss(model="net-lin",
                                   net="vgg",
                                   use_gpu=device.startswith("cuda"))
示例#17
0
def main(args):
    """Run single-GPU inference for a trained generator and report metrics.

    Loads the test split, restores the generator weights from
    ``args.model_path``, generates images from wireframe (and optionally
    color-guide) inputs, saves them to ``args.out_path``, and prints
    SSIM, LPIPS and FID over the whole test set.

    Args:
        args: parsed CLI namespace; uses ``gpu``, ``batch_size``,
            ``num_workers``, ``nocolor``, ``model_path``, ``out_path``
            and ``img_size``.
    """
    # by default we only consider single gpu inference
    assert (len(args.gpu) == 1)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # load data
    data_loader_val, num_test = get_loader(args,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           training=False)
    print('finished data loading')

    # Generator: color guidance is enabled unless explicitly disabled.
    colorguide = not args.nocolor
    netG = Generator(lambdas=None,
                     colorguide=colorguide,
                     input_nc=1,
                     output_nc=1)

    netG.load_state_dict(torch.load(args.model_path))

    if torch.cuda.is_available():
        netG = netG.cuda()

    out_path = args.out_path
    checkpath(out_path)

    # Inception activations collected for the FID computation below.
    predictions_fid_real = []
    predictions_fid_fake = []
    fid_model = InceptionV3().cuda()
    fid_model.eval()
    Perceptual = PNet().cuda()

    # Accumulate plain Python floats (not tensors) so no GPU memory is
    # retained across batches; also avoid shadowing the `lpips` package
    # name with the running total.
    avg_ssim = 0.0
    lpips_total = 0.0

    # validate on test set, TODO: test with single color guide image
    with torch.no_grad():
        netG.eval()
        for i, (img_real, wf_real,
                color_real) in enumerate(data_loader_val, 0):
            img_real = img_real.cuda()
            wf_real = wf_real.cuda()
            if colorguide:
                color_real = color_real.cuda()
            # in case we are in the last iteration (possibly smaller batch)
            batch_size = img_real.size(0)

            img_fake, wf_fake, _, _, _, _, _ = netG(trainG=False,
                                                    img_real=None,
                                                    wf_real=wf_real,
                                                    color_real=color_real)

            # Weight per-batch averages by batch size so the final
            # division by num_test yields the dataset mean.
            avg_ssim += ssim(img_real, img_fake).item() * batch_size
            lpips_total += Perceptual(img_real, img_fake).item() * batch_size

            # TODO: save generated wireframes
            save_singleimages(img_fake, out_path, i * args.batch_size,
                              args.img_size)

            pred_fid_real = fid_model(img_real)[0]
            pred_fid_fake = fid_model(img_fake)[0]
            predictions_fid_real.append(
                pred_fid_real.data.cpu().numpy().reshape(batch_size, -1))
            predictions_fid_fake.append(
                pred_fid_fake.data.cpu().numpy().reshape(batch_size, -1))

        print('SSIM: {:6f}'.format(avg_ssim / num_test))

        print('LPIPS: {:6f}'.format(lpips_total / num_test))

        predictions_fid_real = np.concatenate(predictions_fid_real, 0)
        predictions_fid_fake = np.concatenate(predictions_fid_fake, 0)
        fid = compute_fid_score(predictions_fid_fake, predictions_fid_real)
        print('FID: {:6f}'.format(fid))
示例#18
0
    if args.gif and args.num_samples < 2:
        raise ValueError('for GIF num_samples must be greater than 1')

    if not args.out_path:
        out_path = 'ProGAN-Anime-Faces'
    else:
        out_path = args.out_path

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    gen = Generator(cfg.Z_DIMENSION, cfg.IN_CHANNELS,
                    cfg.CHANNELS_IMG).to(device)
    load_gen(gen, args.path_ckpt, device)
    gen.eval()
    alpha = 1
    step = int(log2(cfg.START_TRAIN_IMG_SIZE / 4))

    if args.grid:
        noise = get_random_noise(args.num_samples, args.z_size, device)
        print("==> Generate IMAGE GRID...")
        output = gen(noise, alpha, step)
        show_batch(output,
                   out_path,
                   num_samples=args.num_samples,
                   figsize=(args.img_size, args.img_size))
    elif args.gif:
        noise = get_random_noise(args.num_samples, args.z_size, device)
        "--out_prefix",
        type=str,
        default="factor",
        help="filename prefix to result samples",
    )
    parser.add_argument(
        "factor",
        type=str,
        help="name of the closed form factorization result factor file",
    )

    args = parser.parse_args()

    eigvec = torch.load(args.factor)["eigvec"].to(args.device)
    ckpt = torch.load(args.ckpt)
    g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)
    g.load_state_dict(ckpt["g_ema"], strict=False)

    trunc = g.mean_latent(4096)

    latent = torch.randn(args.n_sample, 512, device=args.device)
    latent = g.get_latent(latent)

    direction = args.degree * eigvec[:, args.index].unsqueeze(0)

    img, _ = g(
        [latent],
        truncation=args.truncation,
        truncation_latent=trunc,
        input_is_latent=True,
    )
示例#20
0
    parser.add_argument(
        "--inception",
        type=str,
        default=None,
        required=True,
        help="path to precomputed inception embedding",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
    )

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
示例#21
0
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
from data.dataloader import make_dataloader

ImageFile.LOAD_TRUNCATED_IMAGES = True
opt = TrainOptions().parse()
localtime = time.asctime(time.localtime(time.time()))

start_time = time.time()

train_loader, valid_loader = make_dataloader(opt)

# Decide which device we want to run on
device = torch.device("cuda:0" if (
    torch.cuda.is_available() and opt.ngpu > 0) else "cpu")

# Create the generator
netG = Generator().to(device)
if (device.type == 'cuda') and (opt.ngpu > 1):
    netG = nn.DataParallel(netG, list(range(opt.ngpu)))
netG.apply(weights_init)
print(netG)

# Create the Discriminator
netD = Discriminator().to(device)
if (device.type == 'cuda') and (opt.ngpu > 1):
    netD = nn.DataParallel(netD, list(range(opt.ngpu)))
netD.apply(weights_init)
print(netD)

# Initialize BCELoss function
criterion = nn.BCELoss()
l1_loss = nn.L1Loss(reduction='sum')
示例#23
0
File: train.py  Project: DeepSea16/myangn
opt = parser.parse_args()

# dataset and dataloader
path = "D:\\python\\worksapce\\P_github_project_INFOgan_mnist_ten\\data"
dataset = MyDataset(path, opt)
dataloader = DataLoader(dataset, batch_size=opt.batchSize, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 1

netG = Generator(ngpu, nz, ngf, nc).to(device)
netD = Discriminator(ngpu, nc, ndf).to(device)
netQ = QNet(ngpu, nc, ndf).to(device)

criterion = nn.BCELoss()
qnetloss = QNetLoss()

optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerQ = optim.Adam(netQ.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

if __name__ == "__main__":
    netG.apply(weights_init)
    netD.apply(weights_init)

    fixed_z = Variable(torch.randn(opt.batchSize, nz - 10, 1, 1)).to(device)
示例#24
0
    data_range = 10
    batch_size = args.batch_size
    num_epochs = args.max_epoch
    input_dim = 1
    hidden_dim = 32
    output_dim = 1
    num_epochs = 100000
    num_epochs_pre = 500
    learning_rate = 0.03

    # Samples
    data = DataDistribution(mu, sigma)
    gen = NoiseDistribution(data_range)

    # Models
    G = Generator(input_dim, hidden_dim, output_dim)
    D = Discriminator(input_dim, hidden_dim, output_dim)

    # Loss function
    criterion = torch.nn.BCELoss()

    # optimizer
    optimizer = torch.optim.SGD(D.parameters(), lr=learning_rate)

    D_pre_losses = []
    num_samples_pre = 5000
    num_bins_pre = 100
    for epoch in range(num_epochs_pre):
        # Generate samples
        d = data.sample(num_samples_pre)
        histc, edges = np.histogram(d, num_bins_pre, density=True)
示例#25
0
    # visdom setting
    vissum = VisdomSummary(port=FG.vis_port, env=FG.vis_env)

    # Dimensionality of the latent vector.
    latent_size = 512
    # Use sigmoid activation for the last layer?
    sigmoid_at_end = FG.gan in ['lsgan', 'gan']
    if hasattr(FG, 'no_tanh'):
        tanh_at_end = False
    else:
        tanh_at_end = True

    G = Generator(num_channels=3,
                  latent_size=latent_size,
                  resolution=FG.target_resol,
                  fmap_max=latent_size,
                  fmap_base=8192,
                  tanh_at_end=tanh_at_end).to(device)
    D = Discriminator(num_channels=3,
                      mbstat_avg=FG.mbstat_avg,
                      resolution=FG.target_resol,
                      fmap_max=latent_size,
                      fmap_base=8192,
                      sigmoid_at_end=sigmoid_at_end).to(device)

    print(G)
    print(D)
    # exit()

    if len(FG.gpu) != 1:
        G = torch.nn.DataParallel(G, FG.gpu)
示例#26
0
def test(args):
    """Evaluate a trained pix2pix-style generator on the test split.

    Restores the generator from ``<evaluation_path>/best.pth.tar``, runs it
    over the test set with batch size 1, prints the per-iteration L1 loss,
    optionally displays input/output pairs, and prints the mean L1 loss.

    Args:
        args: config namespace; uses ``DATA`` (data_path, dataset,
            direction), ``MODEL`` (ngf, input_nc, output_nc) and
            ``EVALUATION`` (evaluation_path, plot).
    """
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    transform = [transforms.Resize((256, 256), Image.BICUBIC),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    test_data_loader = dataset(
                            root_dir = args.DATA.data_path,
                            dataset = args.DATA.dataset,
                            mode = 'test',
                            direction = args.DATA.direction,
                            transform=transform)


    test_loader = DataLoader(test_data_loader,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4)

    # Define model and restore the best checkpoint.
    generator = Generator(args.MODEL.ngf,
                          args.MODEL.input_nc,
                          args.MODEL.output_nc)
    checkpoint = args.EVALUATION.evaluation_path
    checkpoint = os.path.join(checkpoint, "best.pth.tar")
    load(generator, checkpoint)
    generator.to(device)
    L1_loss = torch.nn.L1Loss()

    # Accumulate a plain Python float so no tensors are kept across
    # iterations.
    t_loss = 0.0
    generator.eval()

    for i, batch in enumerate(test_loader):
        # `inp` rather than `input` to avoid shadowing the builtin.
        inp = batch['input'].to(device)
        target = batch['target'].to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = generator(inp)
        test_loss = L1_loss(output, target)

        t_loss += test_loss.item()

        if args.EVALUATION.plot:
            # Display input/output side by side.
            inp_np = inp.detach().cpu().numpy()
            inp_np = inp_np.squeeze().transpose((1, 2, 0))

            out_np = output.detach().cpu().numpy()
            out_np = out_np.squeeze().transpose((1, 2, 0))

            fig = plt.figure()
            ax1 = fig.add_subplot(1, 2, 1)
            ax1.set_title('input')
            ax1.imshow(inp_np)

            ax2 = fig.add_subplot(1, 2, 2)
            ax2.set_title('output')
            ax2.imshow(out_np)
            plt.pause(1)

        print("Iter:", i, "L1 loss:", test_loss.item())
    print("Final Test Loss: ", t_loss / len(test_loader))
示例#27
0
        device = "cpu"

    print(f"=> Called with args {args.__dict__}")
    print(f"=> Config params {cfg.__dict__}")
    print(f"=> Run on device {device}")
    # define dataset and dataloader
    dataset = AnimeFacesDataset(args.data_path)
    cfg.DATASET_SIZE = len(dataset)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.BATCH_SIZE,
                            shuffle=True,
                            num_workers=2,
                            drop_last=True,
                            pin_memory=True)
    # define models
    gen = Generator(cfg.Z_DIMENSION, cfg.CHANNELS_IMG,
                    cfg.FEATURES_GEN).to(device)
    critic = Critic(cfg.CHANNELS_IMG, cfg.FEATURES_DISC).to(device)

    if args.checkpoint_path:
        opt_gen = optim.Adam(gen.parameters(),
                             lr=cfg.LEARNING_RATE,
                             betas=(0.5, 0.999))
        opt_critic = optim.Adam(critic.parameters(),
                                lr=cfg.LEARNING_RATE,
                                betas=(0.5, 0.999))
        cp = torch.load(args.checkpoint_path)
        start_epoch, end_epoch, fixed_noise = load_checkpoint(
            cp, gen, critic, opt_gen, opt_critic)
        cfg.NUM_EPOCHS = end_epoch
    else:
        print("=> Init default weights of models and fixed noise")
示例#28
0
                                                batch_size=params.batch_size,
                                                shuffle=True)

# Test data
test_data = DatasetFromFolder(data_dir,
                              val_file,
                              params.img_types,
                              transform=transform)
test_data_loader = torch.utils.data.DataLoader(dataset=test_data,
                                               batch_size=params.batch_size,
                                               shuffle=False)
# test_input, test_target = test_data_loader.__iter__().__next__()

# Models
torch.cuda.set_device(params.gpu)
G = Generator(3, params.ngf, 3)
D = Discriminator(6, params.ndf, 1)
G.cuda()
D.cuda()
G.normal_weight_init(mean=0.0, std=0.02)
D.normal_weight_init(mean=0.0, std=0.02)

slim_params, insnorm_params = [], []
for name, param in G.named_parameters():
    if param.requires_grad and name.endswith(
            'weight') and 'insnorm_conv' in name:
        insnorm_params.append(param)
        if len(slim_params) % 2 == 0:
            slim_params.append(param[:len(param) // 2])
        else:
            slim_params.append(param[len(param) // 2:])
示例#29
0
G_path = os.path.join('./checkpoints', opt.checkpoint, 'weight',
                      'netG_epoch_' + str(opt.startEpoch) + '.pth')
D_path = os.path.join('./checkpoints', opt.checkpoint, 'weight',
                      'netD_epoch_' + str(opt.startEpoch) + '.pth')

print(opt)

driving_video_loader = make_test_dataloader(opt)

# Decide which device we want to run on
device = torch.device("cuda:0" if (
    torch.cuda.is_available() and ngpu > 0) else "cpu")

# Create the generator
netG = Generator(ngf).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
print(netG)

# load  model
netG.load_state_dict(torch.load(G_path))
print('load %s !' % G_path)

# source image
source_image = Image.open(opt.sourceImage).convert('RGB')
source_image_transform = transforms.Compose([
    transforms.Resize((opt.imageSize, opt.imageSize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
示例#30
0
    if args.fid:
        fid_model = InceptionV3FID(device)
    else:
        fid_model = None

    # defining dataset and dataloader
    dataset = ImgFolderDataset(args.data_path)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.BATCH_SIZE,
                            shuffle=True,
                            num_workers=2,
                            drop_last=True,
                            pin_memory=True)
    # defining models
    gen = Generator(img_size=cfg.IMG_SIZE,
                    in_channels=cfg.IN_CHANNELS,
                    img_channels=cfg.CHANNELS_IMG,
                    z_dim=cfg.Z_DIMENSION).to(device)
    dis = Discriminator(img_size=cfg.IMG_SIZE,
                        img_channels=cfg.CHANNELS_IMG).to(device)
    # defining optimizers
    opt_gen = optim.Adam(params=gen.parameters(),
                         lr=cfg.LEARNING_RATE,
                         betas=(0.0, 0.99))
    opt_dis = optim.Adam(params=dis.parameters(),
                         lr=cfg.LEARNING_RATE,
                         betas=(0.0, 0.99))
    # defining gradient scalers for automatic mixed precision
    scaler_gen = torch.cuda.amp.GradScaler()
    scaler_dis = torch.cuda.amp.GradScaler()

    if args.checkpoint: