Example #1
import torch


def traverse():
    opt = get_traverse_options()

    from models.dcgan import Generator
    from utils.misc import TensorImageUtils

    utiler = TensorImageUtils()
    in_channels = opt.in_channels
    if opt.data_name == "MNIST":
        in_channels = 1
    dim_z = opt.dim_z
    num_classes = opt.ndlist
    num_categorical_variables = len(num_classes)
    num_continuous_variables = opt.ncz

    device = torch.device("cuda:0")
    netG = Generator(in_channels, dim_z).to(device)

    netG.load_state_dict(torch.load(opt.model_path, map_location=device))

    g = NoiseGenerator(dim_z, num_classes, num_continuous_variables)
    z = g.traversal_get(opt.batch_size, opt.cidx, opt.didx, opt.c_range,
                        opt.seed, opt.fixmode)
    # z = g.random_get(100)
    # z = torch.cat(z[:3], dim=-1)
    print(z.size())

    x = netG(z)
    output_name = "{}.png".format(opt.out_name)
    utiler.save_images(x, output_name, nrow=opt.nrow)
    print("Save traversal image in {}".format(output_name))
Example #2
def train(args):
    declare_global_parameter(args)

    netG = Generator(NZ + NO * 5, NC, DIM, IMAGE_SIZE, 0)
    netD = Discriminator(NZ, NC, DIM, IMAGE_SIZE, 0)
    netC = Critic(NZ, NC, DIM, IMAGE_SIZE, 0)
    netE = models.encoder.Encoder(NZ, NC, NO, DIM, IMAGE_SIZE // 2, 0)

    print(netG)
    print(netC)
    print(netD)
    print(netE)

    if CUDA:
        netG.cuda()
        netC.cuda()
        netD.cuda()
        netE.cuda()
        if NGPU > 1:
            netG = nn.DataParallel(netG, device_ids=range(NGPU))
            netC = nn.DataParallel(netC, device_ids=range(NGPU))
            netD = nn.DataParallel(netD, device_ids=range(NGPU))
            netE = nn.DataParallel(netE, device_ids=range(NGPU))

    cudnn.benchmark = True

    optimizerD = optim.Adam(netD.parameters(), lr=LR_D, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=LR_G, betas=(0.5, 0.9))
    optimizerC = optim.Adam(netC.parameters(), lr=LR_D, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=LR_G, betas=(0.5, 0.9))

    # Dataset loader
    transform = tv.transforms.Compose(
        [tv.transforms.Resize(IMAGE_SIZE),  # Scale was renamed to Resize in torchvision
         tv.transforms.ToTensor()])

    if DATASET == 'mnist':
        dataset = tv.datasets.MNIST(DATA_SAVE,
                                    train=True,
                                    transform=transform,
                                    download=True)
    else:
        raise ValueError("unsupported dataset: {}".format(DATASET))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=BATCH_SIZE * 2,
                                             shuffle=True,
                                             num_workers=NW,
                                             pin_memory=CUDA,
                                             drop_last=True)

    for epoch in range(EPOCHS):
        run_epoch(dataloader, netG, netE, netC, netD, optimizerG, optimizerE,
                  optimizerC, optimizerD)
Example #3
def main():

    # init random seed
    init_random_seed(params.manual_seed)
    # check the directories required by the config
    check_dirs()

    cudnn.benchmark = True
    torch.cuda.set_device(params.gpu_id[0])  # set the current device

    print('=== Build model ===')
    # GPU mode
    generator = Generator()
    discriminator = Discriminator()
    generator = nn.DataParallel(generator, device_ids=params.gpu_id).cuda()
    discriminator = nn.DataParallel(discriminator,
                                    device_ids=params.gpu_id).cuda()

    # restore trained model
    if params.generator_restored:
        generator = restore_model(generator, params.generator_restored)
    if params.discriminator_restored:
        discriminator = restore_model(discriminator,
                                      params.discriminator_restored)

    # container of training
    trainer = Trainer(generator, discriminator)

    if params.mode == 'train':
        # data loader
        print('=== Load data ===')
        train_dataloader = get_data_loader(params.dataset)

        print('=== Begin training ===')
        trainer.train(train_dataloader)
        print('=== Generate {} images, saving to {} ==='.format(
            params.num_images, params.save_root))
        trainer.generate_images(params.num_images, params.save_root)
    elif params.mode == 'test':
        if params.generator_restored:
            print('=== Generate {} images, saving to {} ==='.format(
                params.num_images, params.save_root))
            trainer.generate_images(params.num_images, params.save_root)
        else:
            # asserts can be stripped with python -O; raise instead
            raise RuntimeError('[*] load a trained Generator model first!')
    else:
        raise RuntimeError("[*] mode must be 'train' or 'test'!")
Example #4
def main():
    args = parse_args()
    # pt_path = config['pretrained']+'_netG.pt'
    pt_path = args.weights
    if not os.path.isfile(pt_path):
        print(f"checkpoint file {pt_path} does not exist.")
        return

    if not os.path.isdir(args.output):
        os.makedirs(args.output, exist_ok=True)

    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and int(config['ngpu']) > 0) else "cpu")

    # Create the generator
    netG = Generator(config).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (int(config['ngpu']) > 1):
        netG = nn.DataParallel(netG, list(range(int(config['ngpu']))))

    netG.load_state_dict(torch.load(pt_path, map_location=device))
    netG.eval()
    # Print the model
    print(netG)

    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    fixed_noise = torch.randn(args.batch_size,
                              int(config['nz']),
                              1,
                              1,
                              device=device)

    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
        if args.grid:
            grid = vutils.make_grid(fake, padding=2, normalize=True)
            ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
                'cpu', torch.uint8).numpy()
            im = Image.fromarray(ndarr)
            im.save(os.path.join(args.output, "grid.png"))
        else:
            for idx, tensor in enumerate(fake):
                image = util.tensor2im(tensor, normalize=True)
                im = Image.fromarray(image)
                im.save(os.path.join(args.output, f"{idx}.png"))
Example #5
def get_gan(gan_type: GANType,
            device: torch.device,
            n_power_iterations: Optional[int] = None) -> Tuple[nn.Module, nn.Module]:
    r"""Fetch a GAN and move it to the proper device.

    Args:
        gan_type (GANType): DCGAN or SN-DCGAN.
        device (torch.device): Device (e.g. a GPU) to move the models to.
        n_power_iterations (Optional[int]): Number of power iterations for the
            l_2 matrix norm estimate used by spectral normalization.

    Returns:
        G (nn.Module): Generator.
        D (nn.Module): Discriminator.
    """
    if gan_type == GANType.DCGAN:
        G = Generator().to(device)
        D = Discriminator().to(device)
    elif gan_type == GANType.SN_DCGAN:
        G = SNGenerator().to(device)
        D = SNDiscriminator(n_power_iterations).to(device)
    else:
        raise ValueError("unsupported GAN type: {}".format(gan_type))

    return G, D
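
A minimal usage sketch for `get_gan` (the device line and `n_power_iterations=1` are illustrative values, not project defaults):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# plain DCGAN; the power-iteration argument is unused in this branch
G, D = get_gan(GANType.DCGAN, device)

# spectrally normalized variant
G_sn, D_sn = get_gan(GANType.SN_DCGAN, device, n_power_iterations=1)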
Example #6
        paths = [self.fake_path, self.true_path]
        print("Evaluating FID Score.")
        fid_value = calculate_fid_given_paths(paths,
                                              self.batch_size,
                                              device,
                                              dims=2048,      # InceptionV3 feature dims (default)
                                              num_workers=8)
        print('FID: ', fid_value)
        return fid_value


if __name__ == '__main__':
    # start evaluation.
    opt = get_traverse_options()
    evaluator = FIDEvaluator(N=5000, batch_size=opt.batch_size)

    # load data
    data = choose_dataset(opt)
    if opt.data_name in ("MNIST", "fashion"):
        opt.in_channels = 1

    # build and load the model on the selected device
    device = torch.device(
        "cuda" if opt.cuda and torch.cuda.is_available() else "cpu")
    netG = Generator(opt.in_channels, opt.dim_z).to(device)
    netG.load_state_dict(torch.load(opt.model_path, map_location=device))
    noiseG = NoiseGenerator(opt.dim_z, opt.ndlist, opt.ncz)

    evaluator.evaluate(netG, noiseG, data, device)
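
For reference, a standalone call to `calculate_fid_given_paths`, which appears to come from the pytorch-fid package, is sketched below; the directory names and batch size are illustrative:

import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fid = calculate_fid_given_paths(
    ["generated_images/", "real_images/"],  # two directories of images
    batch_size=50,
    device=device,
    dims=2048,      # InceptionV3 final-pool feature dimensionality
    num_workers=8)
print("FID:", fid)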
Example #7
import os

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from models.dcgan import Generator, Discriminator, weights_init
from train_gan import Trainer
# Text2ImageDataset (used below) comes from the project's dataset module,
# which is not shown in this snippet.


if __name__ == '__main__':
    p2data = "/hdd1/diploma_outputs/outputs"
    dset_name = "train_flowers.hdf5"

    imageSize = 64

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=imageSize, scale=(0.9, 0.95), ratio=(1, 1)),
        transforms.ToTensor()])

    batchSize = 256
    workers = 4

    dset = Text2ImageDataset(os.path.join(p2data, dset_name), split="train", transform=transform, mean=0, std=1)

    gen = Generator()
    discr = Discriminator()

    # Trainer arguments: gen, discr, type, dataset, lr, diter, vis_screen,
    # save_path, l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
    # batch_size, num_workers, epochs
    gan_trainer = Trainer(gen, discr, dset, 0.0002, "gan4", "output1",
                          50, 100, False, False, 64, 4, 30)
    gan_trainer.train(cls=True, spe=125)
Example #8
    device = torch.device("cuda:0" if opt.cuda else "cpu")
    ngpu = int(opt.ngpu)
    nz = int(opt.nz)
    ngf = int(opt.ngf)
    ndf = int(opt.ndf)

    # custom weights initialization called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    netG = Generator().to(device)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(torch.load(opt.netG, map_location=device))
    print(netG)

    netD = Discriminator().to(device)
    netD.apply(weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD, map_location=device))
    print(netD)

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
    real_label = 1
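
The in-place `.data` initialization above still works but predates the `torch.nn.init` API; a sketch of the same DCGAN scheme written with `nn.init` (an alternative, not part of the original snippet):

import torch.nn as nn

def weights_init_modern(m):
    # same DCGAN scheme: Conv weights ~ N(0, 0.02),
    # BatchNorm scales ~ N(1, 0.02), BatchNorm biases = 0
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)

# usage: netG.apply(weights_init_modern)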
Example #9
save_epoch_interval = opt.save_epoch_interval
train_D_iter = opt.train_D_iter

# persistence related parameters
writer = SummaryWriter(save_path)
utiler = TensorImageUtils(save_path)
nrow = opt.nrow

# dataset
data = choose_dataset(opt)
in_channels = opt.in_channels

# models
dim_z = opt.dim_z
netD = Discriminator(in_channels, dim_z)
netG = Generator(in_channels, dim_z)
headD = DHead(512)
headQ = QHead(512, num_classes, num_continuous_variables)

if use_cuda:
    netD.cuda()
    netG.cuda()
    headD.cuda()
    headQ.cuda()

# training config
# NOTE: the original snippet is truncated below; the learning rate and betas
# are illustrative placeholders, not the original values.
optimizer_D = optim.Adam([{
    "params": netD.parameters()
}, {
    "params": headD.parameters()
}], lr=2e-4, betas=(0.5, 0.999))
Example #10
def main():
    load_pretrained = False
    if os.path.isfile(config['pretrained'] + '_netG.pt'):
        load_pretrained = True
        netD_path = config['pretrained'] + '_netD.pt'
        netG_path = config['pretrained'] + '_netG.pt'
        # recover epoch/iter from a checkpoint name of the form "{epoch}_{iter}_..."
        current_epoch = int(config['pretrained'].split(os.path.sep)[-1].split("_")[0]) + 1
        current_iter = int(config['pretrained'].split(os.path.sep)[-1].split("_")[1])
        print(current_epoch, current_iter)
        print("resuming from pretrained checkpoint")
    else:
        current_epoch = 0

    dataset = PokemonDataset(dataroot=config['dataroot'],
                             transform=transforms.Compose([
                                 transforms.Resize(int(config['image_size'])),
                                 transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                                        saturation=0.2, hue=0),
                                 transforms.RandomRotation(10),
                                 transforms.RandomHorizontalFlip(p=0.5),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                             ]),
                             config=config)
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=int(config['batch_size']),
                                             shuffle=True, num_workers=int(config['workers']))

    device = torch.device("cuda:0" if (torch.cuda.is_available() and int(config['ngpu']) > 0) else "cpu")

    # Create the generator
    netG = Generator(config).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (int(config['ngpu']) > 1):
        netG = nn.DataParallel(netG, list(range(int(config['ngpu']))))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.02.
    if load_pretrained:
        netG.load_state_dict(torch.load(netG_path))
    else:
        netG.apply(weights_init)
    netG.train()

    # Print the model
    print(netG)

    # Create the discriminator
    netD = Discriminator(config).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (int(config['ngpu']) > 1):
        netD = nn.DataParallel(netD, list(range(int(config['ngpu']))))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.02.
    if load_pretrained:
        netD.load_state_dict(torch.load(netD_path))
    else:
        netD.apply(weights_init)

    netD.train()
    # Print the model
    print(netD)


    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    fixed_noise = torch.randn(64, int(config['nz']), 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 0.9    # GAN tricks #1: label smoothing
    fake_label = 0

    # Setup Adam optimizers for both G and D
    # optimizerD = optim.Adam(netD.parameters(), lr=float(config['netD_lr']), betas=(float(config['beta1']), 0.999))
    optimizerD = optim.Adam(filter(lambda p: p.requires_grad, netD.parameters()), lr=float(config['netD_lr']), betas=(float(config['beta1']), 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=float(config['netG_lr']), betas=(float(config['beta1']), 0.999))

    # Training Loop
    num_epochs = int(config['num_epochs'])
    nz = int(config['nz'])
    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    frames = []
    iters = 0
    if load_pretrained:
        iters = current_iter

    print("Starting Training Loop...")
    start_time = time.time()
    # For each epoch
    for epoch in range(current_epoch, num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                end_time = time.time()
                duration = end_time - start_time
                print(f"{duration:.2f}s, [{epoch}/{num_epochs}][{i}/{len(dataloader)}]\tLoss_D: {errD.item():.4f}\t  \
                Loss_G: {errG.item():.4f}\tD(x): {D_x:.4f}\tD(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}")
                start_time = time.time()

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % int(config['save_freq']) == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                grid = vutils.make_grid(fake, padding=2, normalize=True)
                ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                im = Image.fromarray(ndarr)
                im.save(os.path.join("output", f"epoch{epoch}_iter{iters}.png"))
                frames.append(im)
                torch.save(netD.state_dict(), os.path.join("output", f"{epoch}_{iters}_netD.pt"))
                torch.save(netG.state_dict(), os.path.join("output", f"{epoch}_{iters}_netG.pt"))

            iters += 1

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(os.path.join("output", "loss_curve.png"))
    frames[0].save(os.path.join('output', 'animation.gif'), format='GIF',
                   append_images=frames[1:], save_all=True, duration=500, loop=0)
Example #11
def get_models(args, log=True):
    """
    Get models based on configuration.
    """
    # ebm
    if args.dataset in TOY_DSETS:
        if args.mog_comps is not None:
            logp_net = MOG(args.data_dim, args.mog_comps)
        else:
            logp_net = small_mlp_ebm(args.data_dim, args.h_dim)
    elif args.dataset in TAB_DSETS:
        nout = args.num_classes if args.clf else 1
        logp_net = large_mlp_ebm(args.data_dim, nout=nout, weight_norm=False)
    elif args.dataset in ["mnist", "stackmnist"]:
        nout = 10 if args.clf else 1  # note: clf for stackmnist doesn't work
        if args.img_size is not None:
            if args.dataset == "mnist":
                logp_net = DCGANDiscriminator(in_channels=1,
                                              img_size=args.img_size,
                                              nout=nout)
            elif args.dataset == "stackmnist":
                logp_net = DCGANDiscriminator(img_size=args.img_size,
                                              nout=nout)
            else:
                raise ValueError
        else:
            if args.nice:
                logp_net = NICE(args.data_dim, 1000, 5)
            elif args.mog_comps is not None:
                logp_net = MOG(args.data_dim, args.mog_comps)
            else:
                #logp_net = large_mlp_ebm(args.data_dim, nout=nout)
                logp_net = EnergyModel_mnist(3)
    elif args.dataset in ["svhn", "cifar10", "cifar100"]:
        if args.dataset == "cifar100":
            nout = 100 if args.clf else 1
        else:
            nout = 10 if args.clf else 1
        norm = args.norm
        if args.resnet:
            logp_net = ResNetDiscriminator(nout=nout)
        elif args.wide_resnet:
            logp_net = wideresnet.Wide_ResNet(depth=28,
                                              widen_factor=2,
                                              num_classes=nout,
                                              norm=norm,
                                              dropout_rate=args.dropout)
        elif args.thicc_resnet:
            logp_net = wideresnet.Wide_ResNet(depth=28,
                                              widen_factor=10,
                                              num_classes=nout,
                                              norm=norm,
                                              dropout_rate=args.dropout)
        else:
            if args.norm == "batch":
                logp_net = BNDCGANDiscriminator(nout=nout)
            else:
                #logp_net = DCGANDiscriminator(nout=nout)
                logp_net = EnergyModel()
    else:
        raise ValueError

    # generator
    if args.generator_type in ["verahmc", "vera"]:
        # pick generator architecture based on dataset
        if args.dataset in TOY_DSETS:
            generator_net = small_mlp_generator(args.noise_dim, args.data_dim,
                                                args.h_dim)
        elif args.dataset in TAB_DSETS:
            generator_net = large_mlp_generator(args.noise_dim,
                                                args.data_dim,
                                                no_final_act=True)
        elif args.dataset in ["mnist", "stackmnist"]:
            if args.img_size is not None:
                if args.dataset == "mnist":
                    generator_net = DCGANGenerator(
                        noise_dim=args.noise_dim,
                        unit_interval=args.unit_interval,
                        out_channels=1,
                        img_size=args.img_size)
                elif args.dataset == "stackmnist":
                    generator_net = DCGANGenerator(
                        noise_dim=args.noise_dim,
                        unit_interval=args.unit_interval,
                        out_channels=3,
                        img_size=args.img_size)
                else:
                    raise ValueError
            else:
                #generator_net = large_mlp_generator(args.noise_dim, args.data_dim, args.unit_interval, args.nice)
                generator_net = Generator_mnist(3)
        elif args.dataset in ["svhn", "cifar10", "cifar100"]:
            if args.resnet:
                assert args.noise_dim == 128
                generator_net = ResNetGenerator(args.unit_interval)
            elif args.wide_resnet:
                assert args.noise_dim == 128
                generator_net = ResNetGenerator(args.unit_interval,
                                                feats=args.g_feats)
            elif args.thicc_resnet:
                assert args.noise_dim == 128
                generator_net = ResNetGenerator(args.unit_interval,
                                                feats=args.g_feats)
            else:
                #generator_net = DCGANGenerator(args.noise_dim, args.unit_interval)
                generator_net = Generator(args.noise_dim)
        else:
            raise ValueError

        # wrap architecture with methods to sample and estimate entropy
        if args.generator_type == "verahmc":
            generator = VERAHMCGenerator(generator_net, args.noise_dim,
                                         args.mcmc_lr)
        elif args.generator_type == "vera":
            generator = VERAGenerator(generator_net, args.noise_dim,
                                      args.post_lr)
        else:
            raise ValueError

    else:
        raise ValueError

    def count_parameters(model):
        """
        Total number of model parameters.
        """
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    if log:
        utils.print_log("logp_net", args)
        utils.print_log(logp_net, args)
        utils.print_log("generator", args)
        utils.print_log(generator, args)
        utils.print_log("{} ebm parameters".format(count_parameters(logp_net)),
                        args)
        utils.print_log(
            "{} generator parameters".format(count_parameters(generator)),
            args)

    if args.clf:
        logp_net = JEM(logp_net)
    return logp_net, generator
Example #12
def main():

    params = parseyaml()

    if params['arch'] == 'Generator':

        device = to_gpu(ngpu=params['n_gpu'])

        if params['image_size'] == 64:

            netG = Generator(ngpu=0, nz=256,
                             ngf=64, nc=64).to(device)

        elif params['image_size'] == 128:

            netG = Generator_128(ngpu=0, nz=256,
                                 ngf=64, nc=64).to(device)

        elif params['image_size'] == 256:

            netG = Generator_256(ngpu=0, nz=256,
                                 ngf=64, nc=64).to(device)

        else:
            raise ValueError("unsupported image_size: {}".format(params['image_size']))

        netG.apply(weights_init)
        netG.load_state_dict(torch.load(params['path']))

        netG.eval()

        for i in range(params['quantity']):

            fixed_noise = torch.randn(64, 256, 1, 1, device=device)
            with torch.no_grad():  # inference only; no gradients needed
                fakes = netG(fixed_noise)

            for j in range(len(fakes)):
                save_image(fakes[j], params['out'] + params['run'] +
                           '_' + str(i) + '_' + str(j) + '_img.png')

    else:

        dataloader = dataLoader(
            path=params['path'], image_size=params['image_size'], batch_size=params['batch_size'],
            workers=params['loader_workers'])

        device = to_gpu(ngpu=params['n_gpu'])

        if params['arch'] == 'DCGAN':

            if params['image_size'] == 64:

                netG = Generator(ngpu=params['n_gpu'], nz=params['latent_vector'],
                                 ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

                netD = Discriminator(params['n_gpu'], nc=params['number_channels'],
                                     ndf=params['dis_feature_maps']).to(device)

            elif params['image_size'] == 128:

                netG = Generator_128(ngpu=params['n_gpu'], nz=params['latent_vector'],
                                     ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

                netD = Discriminator_128(params['n_gpu'], nc=params['number_channels'],
                                         ndf=params['dis_feature_maps']).to(device)

            elif params['image_size'] == 256:

                netG = Generator_256(ngpu=params['n_gpu'], nz=params['latent_vector'],
                                     ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

                netD = Discriminator_256(params['n_gpu'], nc=params['number_channels'],
                                         ndf=params['dis_feature_maps']).to(device)

        elif params['arch'] == 'SNGAN':

            if params['image_size'] == 64:

                netG = Generator(ngpu=params['n_gpu'], nz=params['latent_vector'],
                                 ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

                netD = Discriminator_SN(params['n_gpu'], nc=params['number_channels'],
                                        ndf=params['dis_feature_maps']).to(device)

            elif params['image_size'] == 128:

                netG = Generator_128(ngpu=params['n_gpu'], nz=params['latent_vector'],
                                     ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

                netD = Discriminator_SN_128(params['n_gpu'], nc=params['number_channels'],
                                            ndf=params['dis_feature_maps']).to(device)

            elif params['image_size'] == 256:

                netG = Generator_256(ngpu=params['n_gpu'], nz=params['latent_vector'],
                                     ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

                netD = Discriminator_SN_256(params['n_gpu'], nc=params['number_channels'],
                                            ndf=params['dis_feature_maps']).to(device)

        if (device.type == 'cuda') and (params['n_gpu'] > 1):
            netG = nn.DataParallel(netG, list(range(params['n_gpu'])))
            netD = nn.DataParallel(netD, list(range(params['n_gpu'])))

        netG.apply(weights_init)
        netD.apply(weights_init)

        print(netG)
        print(netD)

        criterion = nn.BCELoss()

        fixed_noise = torch.randn(params['image_size'],
                                  params['latent_vector'], 1, 1, device=device)

        # a learning_rate >= 1 is treated as a multiplier on the 0.0002 default
        # (applied to the discriminator only); smaller values are used directly
        if params['learning_rate'] >= 1:

            optimizerD = optim.Adam(netD.parameters(), lr=0.0002 * params['learning_rate'], betas=(
                params['beta_adam'], 0.999))
            optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(
                params['beta_adam'], 0.999))

        else:

            optimizerD = optim.Adam(netD.parameters(), lr=params['learning_rate'], betas=(
                params['beta_adam'], 0.999))
            optimizerG = optim.Adam(netG.parameters(), lr=params['learning_rate'], betas=(
                params['beta_adam'], 0.999))

        G_losses, D_losses, img_list, img_list_only = training_loop(
            num_epochs=params['num_epochs'], dataloader=dataloader,
            netG=netG, netD=netD, device=device, criterion=criterion,
            nz=params['latent_vector'], optimizerG=optimizerG,
            optimizerD=optimizerD, fixed_noise=fixed_noise,
            out=params['out'] + params['run'] + '_')

        loss_plot(G_losses=G_losses, D_losses=D_losses, out=params['out'] + params['run'] + '_')

        image_grid(dataloader=dataloader, img_list=img_list,
                   device=device, out=params['out'] + params['run'] + '_')

        compute_metrics(real=next(iter(dataloader)), fakes=img_list_only,
                        size=params['image_size'], out=params['out'] + params['run'] + '_')
Example #13
    transforms.Resize(64),  # Scale was renamed to Resize in torchvision
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:  # Conv weight init
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:  # BatchNorm weight init
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


G = Generator(latent_dim, ngf, channels).cuda(device)
D = Discriminator(ndf, channels).cuda(device)

G.apply(weights_init)
D.apply(weights_init)

dataloader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.ImageFolder(root='C:/Users/우리집/MyWGAN/images',
                                             transform=trans),
    batch_size=batch_size)

Optimizer_D = torch.optim.RMSprop(D.parameters(), lr=lr)
Optimizer_G = torch.optim.RMSprop(G.parameters(), lr=lr)
# Optimizer_G = torch.optim.Adam(G.parameters(),lr=lr, betas=(beta1, 0.999))
# Optimizer_D = torch.optim.Adam(D.parameters(),lr=lr, betas=(beta1, 0.999))