Example #1: tail of a GAN training loop, showing the periodic generator update, sample-grid logging, and checkpointing with Inception Score / FID evaluation.
        if step % args.D_iter == 0:
            # update the generator once every D_iter discriminator steps
            optim_G.zero_grad()
            z = torch.randn(args.batch_size, args.z_dim).to(device)
            loss_G = -net_D(net_G(z)).mean()
            loss_G.backward()
            optim_G.step()

        if step == 0:
            grid = (make_grid(real) + 1) / 2
            train_writer.add_image('real sample', grid)

        if step == 0 or (step + 1) % args.sample_iter == 0:
            fake = net_G(sample_z).cpu()
            grid = (make_grid(fake) + 1) / 2
            valid_writer.add_image('sample', grid, step)
            save_image(grid, os.path.join(log_dir, 'sample', '%d.png' % step))

        if step == 0 or (step + 1) % 10000 == 0:
            torch.save(net_G.state_dict(),
                       os.path.join(log_dir, 'G_%d.pt' % step))
            score, _ = inception_score(valid_dataset, batch_size=64, cuda=True)
            valid_writer.add_scalar('Inception Score', score, step)
            score = fid_score(IgnoreLabelDataset(cifar10),
                              valid_dataset,
                              batch_size=64,
                              cuda=True,
                              normalize=True,
                              r_cache='./.fid_cache/cifar10')
            valid_writer.add_scalar('FID Score', score, step)
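
The full scripts below (Examples #2, #3, and #4) omit their import block and the loop helper that feeds them batches (looper = loop(dataloader)). A minimal sketch of what they likely rely on, assuming SummaryWriter comes from torch.utils.tensorboard and that loop simply cycles the DataLoader endlessly; project-specific pieces (Generator, Discriminator, GenerativeDataset, IgnoreLabelDataset, inception_score, fid_score, calc_gradient_penalty) come from the repository's own modules and are not reproduced here:

import argparse
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import trange


def loop(dataloader):
    # Yield batches forever so training can run for a fixed number of
    # iterations instead of whole epochs.
    while True:
        for batch in dataloader:
            yield batch
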
Example #2: DCGAN-style training on CIFAR-10 with BCE loss, a gradient penalty, and consistency regularization (CR).
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', type=int, default=100000)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--consistency', type=float, default=10)
    parser.add_argument('--warm_up', type=float, default=1000)
    parser.add_argument('--gp_center', type=float, default=1)
    parser.add_argument('--gamma', type=int, default=10)
    parser.add_argument('--name', type=str, default='DCGAN_v2_CR_GP')
    parser.add_argument('--log-dir', type=str, default='log')
    parser.add_argument('--z-dim', type=int, default=128)
    parser.add_argument('--iter-G', type=int, default=3)
    parser.add_argument('--sample-iter', type=int, default=1000)
    parser.add_argument('--sample-size', type=int, default=64)
    args = parser.parse_args()
    log_dir = os.path.join(args.log_dir, args.name)

    device = torch.device('cuda')
    cifar10 = datasets.CIFAR10('./data',
                               train=True,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(cifar10,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             drop_last=True)

    net_G = Generator().to(device)
    net_D = Discriminator().to(device)

    optim_G = optim.Adam(net_G.parameters(), lr=args.lr, betas=(0.5, 0.999))
    optim_D = optim.Adam(net_D.parameters(), lr=args.lr, betas=(0.5, 0.999))

    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))

    # BCEWithLogitsLoss expects floating-point targets
    real_label = torch.full((args.batch_size, 1), 1.).to(device)
    fake_label = torch.full((args.batch_size, 1), 0.).to(device)
    label = torch.cat([real_label, fake_label], dim=0)
    criteria = nn.BCEWithLogitsLoss()

    os.makedirs(os.path.join(log_dir, 'sample'), exist_ok=True)
    sample_z = torch.randn(args.sample_size, args.z_dim).to(device)

    valid_dataset = GenerativeDataset(net_G, args.z_dim, 10000, device)
    looper = loop(dataloader)

    consistency_transforms = transforms.Compose([
        transforms.ToPILImage(mode='RGB'),
        transforms.RandomAffine(0, translate=(0.1, 0.1)),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def consistency_transform_func(images):
        images = deepcopy(images)
        for idx, img in enumerate(images):
            images[idx] = consistency_transforms(img)
        return images

    cs_lambda = args.consistency

    with trange(args.iterations, dynamic_ncols=True) as pbar:
        for step in pbar:
            real, _ = next(looper)
            augment_real = consistency_transform_func(real)
            real = real.to(device)
            augment_real = augment_real.to(device)

            # update discriminator
            z = torch.randn(args.batch_size, args.z_dim).to(device)
            with torch.no_grad():
                fake = net_G(z).detach()

            loss_gp = calc_gradient_penalty(net_D, real, fake, args.gp_center)
            # consistency regularization: D should give the same output for a
            # real image and its augmented copy
            loss_cs = ((net_D(real) - net_D(augment_real))**2).sum()

            pred_D = torch.cat([net_D(real), net_D(fake)])

            loss_D = criteria(
                pred_D, label) + (args.gamma * loss_gp) + (cs_lambda * loss_cs)

            train_writer.add_scalar('regularization/weight', cs_lambda,
                                    step + 1)

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            train_writer.add_scalar('loss', loss_D.item(), step + 1)

            if step % args.iter_G == 0:
                # update generator
                z = torch.randn(args.batch_size, args.z_dim).to(device)
                pred_G = net_D(net_G(z))
                loss_G = criteria(pred_G, real_label)
                optim_G.zero_grad()
                loss_G.backward()
                optim_G.step()
                train_writer.add_scalar('loss/G', loss_G.item(), step + 1)
                pbar.set_postfix(loss_D='%.4f' % loss_D.item(),
                                 loss_G='%.4f' % loss_G.item(),
                                 loss_gp='%.4f' % loss_gp.item())

            if step == 0:
                grid = (make_grid(real[:args.sample_size]) + 1) / 2
                train_writer.add_image('real sample', grid)

            if step == 0 or (step + 1) % args.sample_iter == 0:
                fake = net_G(sample_z).cpu()
                grid = (make_grid(fake) + 1) / 2
                valid_writer.add_image('sample', grid, step)
                save_image(grid,
                           os.path.join(log_dir, 'sample', '%d.png' % step))

            if step == 0 or (step + 1) % 10000 == 0:
                torch.save(net_G.state_dict(),
                           os.path.join(log_dir, 'G_%d.pt' % step))
                score, _ = inception_score(valid_dataset,
                                           batch_size=64,
                                           cuda=True)
                valid_writer.add_scalar('Inception Score', score, step)
                score = fid_score(IgnoreLabelDataset(cifar10),
                                  valid_dataset,
                                  batch_size=64,
                                  cuda=True,
                                  normalize=True,
                                  r_cache='./.fid_cache/cifar10')
                valid_writer.add_scalar('FID Score', score, step)
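
Example #2 weights a gradient penalty loss_gp by --gamma, but calc_gradient_penalty itself is not shown. A sketch assuming the usual interpolation-based penalty, with the target gradient norm exposed as the gp_center argument (center=1 recovers the WGAN-GP penalty, center=0 a zero-centered one):

def calc_gradient_penalty(net_D, real, fake, center=1.0):
    # Sample points on straight lines between real and fake images.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    inter = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    pred = net_D(inter)
    grad = torch.autograd.grad(outputs=pred.sum(),
                               inputs=inter,
                               create_graph=True)[0]
    grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
    # Penalize the deviation of the gradient norm from `center`.
    return ((grad_norm - center) ** 2).mean()
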
Example #3: WGAN on CIFAR-10 trained with RMSprop and weight clipping.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', type=int, default=200000)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--clip', type=float, default=1e-2)
    parser.add_argument('--name', type=str, default='WGAN_v2')
    parser.add_argument('--log-dir', type=str, default='log')
    parser.add_argument('--z-dim', type=int, default=128)
    parser.add_argument('--iter-D', type=int, default=5)
    parser.add_argument('--sample-iter', type=int, default=1000)
    parser.add_argument('--sample-size', type=int, default=64)
    args = parser.parse_args()
    log_dir = os.path.join(args.log_dir, args.name)

    device = torch.device('cuda')
    cifar10 = datasets.CIFAR10('./data',
                               train=True,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(cifar10,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             drop_last=True)

    net_G = Generator(args.z_dim).to(device)
    net_D = Discriminator().to(device)

    optim_G = optim.RMSprop(net_G.parameters(), lr=args.lr)
    optim_D = optim.RMSprop(net_D.parameters(), lr=args.lr)

    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))

    os.makedirs(os.path.join(log_dir, 'sample'), exist_ok=True)
    sample_z = torch.randn(args.sample_size, args.z_dim).to(device)

    valid_dataset = GenerativeDataset(net_G, args.z_dim, 10000, device)
    looper = loop(dataloader)
    with trange(args.iterations, dynamic_ncols=True) as pbar:
        for step in pbar:
            real, _ = next(looper)
            real = real.to(device)

            # update discriminator: maximize D(real) - D(fake)
            z = torch.randn(args.batch_size, args.z_dim).to(device)
            with torch.no_grad():
                fake = net_G(z).detach()
            loss = -net_D(real).mean() + net_D(fake).mean()
            optim_D.zero_grad()
            loss.backward()
            optim_D.step()
            train_writer.add_scalar('loss', -loss.item(), step)
            pbar.set_postfix(loss='%.4f' % -loss.item())

            # clip discriminator weights to enforce the Lipschitz constraint
            for param in net_D.parameters():
                param.data.clamp_(-args.clip, args.clip)

            if step % args.iter_D == 0:
                # update the generator once every iter_D discriminator steps
                z = torch.randn(args.batch_size, args.z_dim).to(device)
                loss_G = -net_D(net_G(z)).mean()
                optim_G.zero_grad()
                loss_G.backward()
                optim_G.step()

            if step == 0:
                grid = (make_grid(real) + 1) / 2
                train_writer.add_image('real sample', grid)

            if step == 0 or (step + 1) % args.sample_iter == 0:
                fake = net_G(sample_z).cpu()
                grid = (make_grid(fake) + 1) / 2
                valid_writer.add_image('sample', grid, step)
                save_image(grid,
                           os.path.join(log_dir, 'sample', '%d.png' % step))

            if step == 0 or (step + 1) % 10000 == 0:
                torch.save(net_G.state_dict(),
                           os.path.join(log_dir, 'G_%d.pt' % step))
                score, _ = inception_score(valid_dataset,
                                           batch_size=64,
                                           cuda=True)
                valid_writer.add_scalar('Inception Score', score, step)
                score = fid_score(IgnoreLabelDataset(cifar10),
                                  valid_dataset,
                                  batch_size=64,
                                  cuda=True,
                                  normalize=True,
                                  r_cache='./.fid_cache/cifar10')
                valid_writer.add_scalar('FID Score', score, step)
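
Every example builds valid_dataset = GenerativeDataset(net_G, args.z_dim, 10000, device) and hands it to inception_score and fid_score. A hypothetical sketch of such a wrapper, assuming it draws a fresh latent vector per index and returns the generated image (the real implementation may batch the sampling or cache the images):

class GenerativeDataset(torch.utils.data.Dataset):
    # Hypothetical sketch: `size` generated images, produced on demand,
    # exposed through the standard Dataset interface for IS/FID.
    def __init__(self, net_G, z_dim, size, device):
        self.net_G = net_G
        self.z_dim = z_dim
        self.size = size
        self.device = device

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        z = torch.randn(1, self.z_dim, device=self.device)
        with torch.no_grad():
            image = self.net_G(z)[0].cpu()
        return image
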
Example #4: hinge-loss (SNGAN-style) training on CIFAR-10 with consistency regularization.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', type=int, default=100000)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--consistency', type=float, default=10)
    parser.add_argument('--name', type=str, default='SNGAN_CR')
    parser.add_argument('--log-dir', type=str, default='log')
    parser.add_argument('--z-dim', type=int, default=128)
    parser.add_argument('--D-iter', type=int, default=5)
    parser.add_argument('--sample-iter', type=int, default=1000)
    parser.add_argument('--sample-size', type=int, default=64)
    args = parser.parse_args()
    log_dir = os.path.join(args.log_dir, args.name)

    device = torch.device('cuda')
    cifar10 = datasets.CIFAR10('./data',
                               train=True,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(cifar10,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             drop_last=True)

    net_G = Generator(args.z_dim).to(device)
    net_D = Discriminator().to(device)

    optim_G = optim.Adam(net_G.parameters(), lr=args.lr, betas=(0.5, 0.999))
    optim_D = optim.Adam(net_D.parameters(), lr=args.lr, betas=(0.5, 0.999))

    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))

    os.makedirs(os.path.join(log_dir, 'sample'), exist_ok=True)
    sample_z = torch.randn(args.sample_size, args.z_dim).to(device)

    valid_dataset = GenerativeDataset(net_G, args.z_dim, 10000, device)
    looper = loop(dataloader)
    consistency_transforms = transforms.Compose([
        transforms.ToPILImage(mode='RGB'),
        transforms.RandomAffine(0, translate=(0.1, 0.1)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def consistency_transform_func(images):
        images = deepcopy(images)
        for idx, img in enumerate(images):
            images[idx] = consistency_transforms(img)
        return images

    cs_lambda = args.consistency

    with trange(args.iterations, dynamic_ncols=True) as pbar:
        for step in pbar:
            real, _ = next(looper)
            augment_real = consistency_transform_func(real)
            real = real.to(device)
            augment_real = augment_real.to(device)

            z = torch.randn(args.batch_size, args.z_dim).to(device)
            with torch.no_grad():
                fake = net_G(z).detach()
            # hinge loss for the discriminator
            loss_real = torch.nn.functional.relu(1 - net_D(real)).mean()
            loss_fake = torch.nn.functional.relu(1 + net_D(fake)).mean()
            # consistency regularization on augmented real images
            loss_cs = ((net_D(real) - net_D(augment_real))**2).sum()
            loss_D = loss_real + loss_fake + cs_lambda * loss_cs
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            train_writer.add_scalar('loss', loss_D.item(), step)
            pbar.set_postfix(loss='%.4f' % loss_D.item())

            if step % args.D_iter == 0:
                # update the generator once every D_iter discriminator steps
                optim_G.zero_grad()
                z = torch.randn(args.batch_size, args.z_dim).to(device)
                loss_G = -net_D(net_G(z)).mean()
                loss_G.backward()
                optim_G.step()

            if step == 0:
                grid = (make_grid(real) + 1) / 2
                train_writer.add_image('real sample', grid)

            if step == 0 or (step + 1) % args.sample_iter == 0:
                fake = net_G(sample_z).cpu()
                grid = (make_grid(fake) + 1) / 2
                valid_writer.add_image('sample', grid, step)
                save_image(grid,
                           os.path.join(log_dir, 'sample', '%d.png' % step))

            if step == 0 or (step + 1) % 10000 == 0:
                torch.save(net_G.state_dict(),
                           os.path.join(log_dir, 'G_%d.pt' % step))
                score, _ = inception_score(valid_dataset,
                                           batch_size=64,
                                           cuda=True)
                valid_writer.add_scalar('Inception Score', score, step)
                score = fid_score(IgnoreLabelDataset(cifar10),
                                  valid_dataset,
                                  batch_size=64,
                                  cuda=True,
                                  normalize=True,
                                  r_cache='./.fid_cache/cifar10')
                valid_writer.add_scalar('FID Score', score, step)
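
The FID call compares the generated images against CIFAR-10 through IgnoreLabelDataset(cifar10). A sketch of that wrapper, assuming it only strips the class labels so both datasets yield bare image tensors:

class IgnoreLabelDataset(torch.utils.data.Dataset):
    # Assumed wrapper: forward indexing to the underlying dataset but
    # return only the image, dropping the label.
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index][0]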