예제 #1
0
def main():
    """Train a mixup-style LSGAN on SVHN.

    Rather than showing the discriminator pure real and pure fake batches,
    every batch is a pixel-wise convex mix of real and generated images with
    a per-sample rate drawn from Beta(0.2, 0.2). MSE loss pushes the
    discriminator's output toward the true mix rate and the generator toward
    its complement. Relies on module-level Generator, Discriminator and
    weights_init.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batch',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--ng_ch', type=int, default=64)  # generator base channel width
    parser.add_argument('--nd_ch', type=int, default=64)  # discriminator base channel width
    parser.add_argument('--epoch',
                        type=int,
                        default=50,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        default='./result',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    batch_size = opt.batch
    epoch_size = opt.epoch

    # Create the output directory; ignore the error if it already exists.
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    # Fix random seeds for reproducibility.
    random.seed(0)
    torch.manual_seed(0)

    # SVHN resized to 64x64 and normalized to [-1, 1].
    # NOTE(review): hue=0.5 jitters hue over its full range while the other
    # jitter parameters are 0 — confirm this aggressive augmentation is intended.
    dataset = dset.SVHN(root='../svhn_root',
                        download=True,
                        transform=transforms.Compose([
                            transforms.Resize(64),
                            transforms.ColorJitter(brightness=0,
                                                   contrast=0,
                                                   saturation=0,
                                                   hue=0.5),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                        ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    # Use the GPU when one is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    nz = int(opt.nz)  # latent vector dimensionality

    # Generator: maps (nz, 1, 1) noise vectors to fake images.
    netG = Generator().to(device)
    netG.apply(weights_init)
    print(netG)

    # Discriminator: trained below to regress the real fraction of a mix.
    netD = Discriminator().to(device)
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()  # least-squares objective; criterion = nn.BCELoss() would give the standard GAN loss

    # Fixed noise so the per-epoch sample grids are comparable across epochs.
    fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)

    # Beta(0.2, 0.2) is bimodal: mix rates concentrate near 0 and 1, so most
    # mixes are dominated by either the real or the fake component.
    beta_dist = torch.distributions.beta.Beta(0.2, 0.2)

    for epoch in range(epoch_size):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)
            sample_size = real_image.size(0)  # may be < batch_size on the last batch
            noise = torch.randn(sample_size, nz, 1, 1, device=device)

            # One mix rate per sample; reshaped to broadcast over C/H/W.
            mix_label = beta_dist.sample((sample_size, )).to(device)
            mix_rate = mix_label.reshape(sample_size, 1, 1, 1)

            ############################
            # (1) Update D network: regress D(mix) toward the real fraction
            ###########################
            # train with real
            netD.zero_grad()

            fake_image = netG(noise)
            mix_image = mix_rate * real_image + (1. - mix_rate) * fake_image

            output = netD(mix_image)
            errD = criterion(output, mix_label)
            errD.backward()
            optimizerD.step()

            D_mix1 = output.mean().item()  # mean D score during the D step

            ############################
            # (2) Update G network: push D(mix) toward the fake fraction
            ###########################
            netG.zero_grad()

            # Regenerate from the same noise so fresh gradients can flow into
            # G (the graph above was consumed by the D update).
            fake_image = netG(noise)
            mix_image = mix_rate * real_image + (1. - mix_rate) * fake_image

            output = netD(mix_image)
            errG = criterion(output, 1. - mix_label)
            errG.backward()  # also writes grads into netD; cleared by netD.zero_grad() next iteration
            optimizerG.step()

            D_mix2 = output.mean().item()  # mean D score after D was updated

            print(
                '[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(mix): {:.3f}/{:.3f}'
                .format(epoch + 1, epoch_size, itr + 1, len(dataloader),
                        errD.item(), errG.item(), D_mix1, D_mix2))

            # Save one grid of real samples once, for visual reference.
            if epoch == 0 and itr == 0:
                vutils.save_image(real_image,
                                  '{}/real_samples.png'.format(opt.outf),
                                  normalize=True,
                                  nrow=10)

        # Per-epoch sample grid generated from the fixed noise.
        fake_image = netG(fixed_noise)
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(
                              opt.outf, epoch + 1),
                          normalize=True,
                          nrow=10)

        # do checkpointing
        # NOTE(review): with the default --epoch 50 this condition never fires,
        # so no checkpoints are ever written — confirm the 100-epoch interval.
        if (epoch + 1) % 100 == 0:
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
예제 #2
0
    L2 = torch.sum(criterion(z[:, 0], label.to(device)))
    return L2


#================================optimizer======================================
def weights_init(m):
    """DCGAN weight initialization, applied via ``net.apply(weights_init)``.

    Conv-like layers get weights ~ N(0, 0.02); batch-norm layers get
    weights ~ N(1, 0.02) and zero bias. All other modules are untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Build the two networks and apply the DCGAN weight initialization.
dis = Discriminator().to(device)
dis.apply(weights_init)
# NOTE(review): Generator takes batch_size at construction — presumably it
# reshapes its input noise internally; confirm against its definition.
gen = Generator(batch_size).to(device)
gen.apply(weights_init)

# Standard DCGAN Adam hyperparameters (lr=2e-4, beta1=0.5).
dis_opt = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
gen_opt = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
print("setted model/loss/optimizer")

#================================train==========================================
print("start training")
iteration_sum = 0  # running iteration counter across epochs
for epo in range(epoch):

    # Per-epoch accumulators (the rest of the loop body lies outside this view).
    running_loss_dis = 0.0
    running_loss_gen = 0.0
    iterations = 0
예제 #3
0
                      img_shape=img_shape)
# Build the discriminator and one Adam optimizer per network.
discriminator = Discriminator(batchnorm=batchnorm, img_shape=img_shape)
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.learning_rate,
                               betas=(opt.beta_1, opt.beta_2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt.learning_rate,
                               betas=(opt.beta_1, opt.beta_2))

# put the nets on device - if a cuda gpu is installed it will use it
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator, discriminator = generator.to(device), discriminator.to(device)

# initialize weights from random distribution with mean 0 and std 0.02
generator.apply(weights_init)
discriminator.apply(weights_init)

# Create the output image directory for this configuration if missing.
if batchnorm:
    if not os.path.isdir(ROOT_DIR + "/images-batchnorm"):
        os.mkdir(ROOT_DIR + "/images-batchnorm")
else:
    if not os.path.isdir(ROOT_DIR + "/images"):
        os.mkdir(ROOT_DIR + "/images")

# start training
current_epoch = 0
for epoch in range(opt.n_epochs):
    for i, (inputs, _) in enumerate(dataloader):
        inputs = inputs.to(device)

        # create the labels for the fake and real images
        # (loop body continues beyond this excerpt)
def main():
    """Train an LSGAN (MSE-loss DCGAN) on STL-10 at 64x64 resolution.

    The 'train+unlabeled' and 'test' splits are concatenated into one
    unlabeled training set. Generator, Discriminator and weights_init are
    defined elsewhere in this module. A sample grid is written every epoch
    and model checkpoints every 50 epochs to --outf.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers', type=int,
                        help='number of data loading workers', default=2)
    parser.add_argument('--batch_size', type=int,
                        default=50, help='input batch size')
    parser.add_argument('--nz', type=int, default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_g', type=int, default=64)
    parser.add_argument('--nch_d', type=int, default=64)
    parser.add_argument('--n_epoch', type=int, default=200,
                        help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', default='./result_lsgan',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass  # output directory already exists

    # Fix random seeds for reproducibility.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Load STL-10. Labels are unused, so the 'train+unlabeled' split is read
    # to maximize the amount of training data. Images are randomly cropped to
    # 64x64, lightly jittered, and normalized to [-1, 1].
    trainset = dset.STL10(root='./dataset/stl10_root', download=True, split='train+unlabeled',
                          transform=transforms.Compose([
                              transforms.RandomResizedCrop(
                                  64, scale=(88 / 96, 1.0), ratio=(1., 1.)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ColorJitter(
                                  brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
                              transforms.ToTensor(),
                              transforms.Normalize(
                                  (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ]))
    testset = dset.STL10(root='./dataset/stl10_root', download=True, split='test',
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(
                                 64, scale=(88 / 96, 1.0), ratio=(1., 1.)),
                             transforms.RandomHorizontalFlip(),
                             transforms.ColorJitter(
                                 brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
                             transforms.ToTensor(),
                             transforms.Normalize(
                                 (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                         ]))
    dataset = trainset + testset    # combine both splits into the training data

    # DataLoader over the combined dataset.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                             shuffle=True, num_workers=int(opt.workers))

    # Use a GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device:', device)
    # BUGFIX: the original called torch.cuda.set_device(1) unconditionally,
    # which raises a RuntimeError on CPU-only machines and is invalid with a
    # single GPU. Only switch to the second GPU when one actually exists.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)

    # Generator G: produces fake images from random latent vectors.
    netG = Generator(nz=opt.nz, nch_g=opt.nch_g).to(device)
    netG.apply(weights_init)    # DCGAN weight initialization
    print(netG)

    # Discriminator D: classifies images as real or fake.
    netD = Discriminator(nch_d=opt.nch_d).to(device)
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()    # least-squares loss: the LSGAN objective

    # Optimizer setup.
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(
        opt.beta1, 0.999), weight_decay=1e-5)  # for discriminator D
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(
        opt.beta1, 0.999), weight_decay=1e-5)  # for generator G

    fixed_noise = torch.randn(opt.batch_size, opt.nz,
                              1, 1, device=device)  # fixed noise for per-epoch snapshots

    # Training loop.
    for epoch in range(opt.n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)     # real images
            sample_size = real_image.size(0)    # actual batch size (last batch may be smaller)
            noise = torch.randn(sample_size, opt.nz, 1, 1,
                                device=device)   # standard-normal noise

            real_target = torch.full(
                (sample_size,), 1., device=device)     # target "1" for real images
            # target "0" for fake images
            fake_target = torch.full((sample_size,), 0., device=device)

            ############################
            # Update discriminator D
            ###########################
            netD.zero_grad()    # reset gradients

            output = netD(real_image)   # D's scores for the real images
            errD_real = criterion(output, real_target)  # loss on real images
            D_x = output.mean().item()

            fake_image = netG(noise)    # generate fakes from the noise

            # detach() keeps the D step from backpropagating into G.
            output = netD(fake_image.detach())
            errD_fake = criterion(output, fake_target)  # loss on fake images
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake    # total discriminator loss
            errD.backward()    # backpropagate
            optimizerD.step()   # update D's parameters

            ############################
            # Update generator G
            ###########################
            netG.zero_grad()    # reset gradients

            output = netD(fake_image)   # re-score the fakes with the updated D
            # G wants D to judge fakes as real, so the target is "1".
            errG = criterion(output, real_target)
            errG.backward()     # backpropagate
            D_G_z2 = output.mean().item()

            optimizerG.step()   # update G's parameters

            print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                  .format(epoch + 1, opt.n_epoch,
                          itr + 1, len(dataloader),
                          errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if epoch == 0 and itr == 0:     # save a grid of real samples once
                vutils.save_image(real_image, '{}/real_samples.png'.format(opt.outf),
                                  normalize=True, nrow=10)

        ############################
        # Per-epoch sample images
        ############################
        fake_image = netG(fixed_noise)  # snapshot fakes from the fixed noise
        vutils.save_image(fake_image.detach(), '{}/fake_samples_epoch_{:03d}.png'.format(opt.outf, epoch + 1),
                          normalize=True, nrow=10)

        ############################
        # Checkpointing
        ############################
        if (epoch + 1) % 50 == 0:   # save both models every 50 epochs
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
예제 #5
0
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=int(workers))

# Pick the training device; use the GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)

# Generator G: produces fake images from random latent vectors.
netG = Generator(nz=nz, nch_g=nch_g).to(device)
netG.apply(weights_init)  # DCGAN weight initialization via weights_init
print(netG)

# Discriminator D: classifies images as real or fake.
netD = Discriminator(nch_d=nch_d).to(device)
netD.apply(weights_init)
print(netD)

criterion = nn.MSELoss()  # least-squares loss: the LSGAN objective

# Optimizer setup.
optimizerD = optim.Adam(netD.parameters(),
                        lr=lr,
                        betas=(beta1, 0.999),
                        weight_decay=1e-5)  # for discriminator D
optimizerG = optim.Adam(netG.parameters(),
                        lr=lr,
                        betas=(beta1, 0.999),
                        weight_decay=1e-5)  # for generator G

fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)  # fixed noise for progress snapshots
예제 #6
0
def main():
    """Train a DCGAN with least-squares (MSE) loss on a selectable dataset.

    Supports imagenet/folder/lfw (via ImageFolder), lsun, cifar10 and fake
    data. Generator and Discriminator are defined elsewhere in this module.
    A sample grid and checkpoints of both networks are written every epoch
    to --outf; --gen/--dis resume from saved state dicts.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset',
                        required=True,
                        help='cifar10 | lsun | imagenet | folder | lfw | fake')
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batchSize',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument(
        '--imageSize',
        type=int,
        default=64,
        help='the height / width of the input image to network')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_gen', type=int, default=512)
    parser.add_argument('--nch_dis', type=int, default=512)
    parser.add_argument('--nepoch',
                        type=int,
                        default=1000,
                        help='number of epochs to train for')
    # BUGFIX: the help strings below claimed defaults (0.0002 / 0.5) that
    # contradicted the actual defaults (0.001 / 0.9).
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, default=0.001')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='beta1 for adam. default=0.9')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--ngpu',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    parser.add_argument('--gen',
                        default='',
                        help="path to gen (to continue training)")
    parser.add_argument('--dis',
                        default='',
                        help="path to dis (to continue training)")
    parser.add_argument('--outf',
                        default='./result',
                        help='folder to output images and model checkpoints')
    parser.add_argument('--manualSeed', type=int, help='manual seed')

    args = parser.parse_args()
    print(args)

    try:
        os.makedirs(args.outf)
    except OSError:
        pass  # output directory already exists

    # Seed the RNGs; draw a random seed when none was supplied and print it
    # so the run can be reproduced.
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", args.manualSeed)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    cudnn.benchmark = True  # let cuDNN autotune conv algorithms

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # Build the dataset: resize/crop to imageSize and normalize to [-1, +1].
    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif args.dataset == 'lsun':
        dataset = dset.LSUN(root=args.dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(args.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))  # [0, +1] -> [-1, +1]
    elif args.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=int(args.workers))

    device = torch.device("cuda:0" if args.cuda else "cpu")
    nch_img = 3  # RGB

    # custom weights initialization called on gen and dis
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
            # BUGFIX: conv layers built with bias=False have m.bias == None;
            # guard so net.apply() does not crash on them.
            if m.bias is not None:
                m.bias.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)

    gen = Generator(args.ngpu, args.nz, args.nch_gen, nch_img).to(device)
    gen.apply(weights_init)
    if args.gen != '':
        gen.load_state_dict(torch.load(args.gen))  # resume generator

    dis = Discriminator(args.ngpu, args.nch_dis, nch_img).to(device)
    dis.apply(weights_init)
    if args.dis != '':
        dis.load_state_dict(torch.load(args.dis))  # resume discriminator

    # Least-squares loss (LSGAN); nn.BCELoss() would give the standard GAN.
    criterion = nn.MSELoss()

    # Fixed noise for an 8x8 preview grid, comparable across epochs.
    fixed_z = torch.randn(8 * 8, args.nz, 1, 1, device=device)
    # BUGFIX: the targets must be floats. With plain ints, torch.full builds
    # int64 tensors and MSELoss raises a dtype error against D's float output.
    a_label = 0.  # target for fakes in the D step
    b_label = 1.  # target for reals in the D step
    c_label = 1.  # target for fakes in the G step (G wants them judged real)

    # setup optimizer
    optim_dis = optim.Adam(dis.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))
    optim_gen = optim.Adam(gen.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))

    for epoch in range(args.nepoch):
        for itr, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            dis.zero_grad()
            real_img = data[0].to(device)
            batch_size = real_img.size(0)
            label = torch.full((batch_size, ), b_label, device=device)

            dis_real = dis(real_img)
            loss_dis_real = criterion(dis_real, label)
            loss_dis_real.backward()

            # train with fake
            z = torch.randn(batch_size, args.nz, 1, 1, device=device)
            fake_img = gen(z)
            label.fill_(a_label)

            # detach() keeps the D step from backpropagating into G.
            dis_fake1 = dis(fake_img.detach())
            loss_dis_fake = criterion(dis_fake1, label)
            loss_dis_fake.backward()

            loss_dis = loss_dis_real + loss_dis_fake  # for logging only
            optim_dis.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            gen.zero_grad()
            label.fill_(c_label)  # fake labels are real for generator cost

            dis_fake2 = dis(fake_img)
            loss_gen = criterion(dis_fake2, label)
            loss_gen.backward()
            optim_gen.step()

            if (itr + 1) % 100 == 0:
                print(
                    '[{}/{}][{}/{}] LossD:{:.4f} LossG:{:.4f} D(x):{:.4f} D(G(z)):{:.4f}/{:.4f}'
                    .format(epoch + 1, args.nepoch, itr + 1, len(dataloader),
                            loss_dis.item(), loss_gen.item(),
                            dis_real.mean().item(),
                            dis_fake1.mean().item(),
                            dis_fake2.mean().item()))
            # loop end iteration

        if epoch == 0:  # save a grid of real samples once
            vutils.save_image(real_img,
                              '{}/real_samples.png'.format(args.outf),
                              normalize=True)

        # Per-epoch preview grid from the fixed noise.
        fake_img = gen(fixed_z)
        vutils.save_image(fake_img.detach(),
                          '{}/fake_samples_epoch_{:04}.png'.format(
                              args.outf, epoch),
                          normalize=True)

        # do checkpointing
        torch.save(gen.state_dict(),
                   '{}/gen_epoch_{}.pth'.format(args.outf, epoch))
        torch.save(dis.state_dict(),
                   '{}/dis_epoch_{}.pth'.format(args.outf, epoch))
예제 #7
0
def main():
    """Train a conditional LSGAN (cGAN) on labeled STL-10 at 64x64.

    The class label is injected on both sides: the generator receives the
    noise vector concatenated with the label (via concat_noise_label) and
    the discriminator receives the image concatenated with label channels
    (via concat_image_label). Generator, Discriminator, weights_init and the
    concat_* helpers are defined elsewhere in this module.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batch_size',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_g', type=int, default=64)
    parser.add_argument('--nch_d', type=int, default=64)
    parser.add_argument('--n_epoch',
                        type=int,
                        default=200,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        default='./result_cgan',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass  # output directory already exists

    # Fix random seeds for reproducibility.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Labels are required, so the unlabeled split is excluded.
    trainset = dset.STL10(root='./dataset/stl10_root',
                          download=True,
                          split='train',
                          transform=transforms.Compose([
                              transforms.RandomResizedCrop(64,
                                                           scale=(88 / 96,
                                                                  1.0),
                                                           ratio=(1., 1.)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ColorJitter(brightness=0.05,
                                                     contrast=0.05,
                                                     saturation=0.05,
                                                     hue=0.05),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))
    testset = dset.STL10(root='./dataset/stl10_root',
                         download=True,
                         split='test',
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(64,
                                                          scale=(88 / 96, 1.0),
                                                          ratio=(1., 1.)),
                             transforms.RandomHorizontalFlip(),
                             transforms.ColorJitter(brightness=0.05,
                                                    contrast=0.05,
                                                    saturation=0.05,
                                                    hue=0.05),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5)),
                         ]))
    dataset = trainset + testset  # combine the labeled train and test splits

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    # Generator G: input dimension is nz plus the 10 label classes.
    netG = Generator(nz=opt.nz + 10, nch_g=opt.nch_g).to(device)
    netG.apply(weights_init)
    print(netG)

    # Discriminator D: input has 3 image channels plus 10 label channels.
    netD = Discriminator(nch=3 + 10, nch_d=opt.nch_d).to(device)
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()  # least-squares loss: the LSGAN objective

    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)

    fixed_noise = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)

    # Preview labels cycling through 0-9. BUGFIX: the original built
    # [0..9] * (batch_size // 10), whose length differs from batch_size
    # whenever batch_size is not a multiple of 10 and would then mismatch
    # fixed_noise inside concat_noise_label; the modulo form is identical
    # for multiples of 10 and correct otherwise.
    fixed_label = [i % 10 for i in range(opt.batch_size)]
    fixed_label = torch.tensor(fixed_label, dtype=torch.long, device=device)

    fixed_noise_label = concat_noise_label(fixed_noise, fixed_label,
                                           device)  # join preview noise and labels

    # Training loop.
    for epoch in range(opt.n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)  # real images
            real_label = data[1].to(device)  # their class labels
            real_image_label = concat_image_label(real_image, real_label,
                                                  device)  # join image and label

            sample_size = real_image.size(0)
            noise = torch.randn(sample_size, opt.nz, 1, 1, device=device)
            fake_label = torch.randint(10, (sample_size, ),
                                       dtype=torch.long,
                                       device=device)  # random labels for the fakes
            fake_noise_label = concat_noise_label(noise, fake_label,
                                                  device)  # join noise and label

            real_target = torch.full((sample_size, ), 1., device=device)
            fake_target = torch.full((sample_size, ), 0., device=device)

            ############################
            # Update discriminator D
            ###########################
            netD.zero_grad()  # reset gradients

            output = netD(real_image_label)  # D's score for real image+label pairs
            errD_real = criterion(output, real_target)
            D_x = output.mean().item()

            fake_image = netG(fake_noise_label)  # fakes conditioned on the labels
            fake_image_label = concat_image_label(fake_image, fake_label,
                                                  device)  # join fake image and label

            # detach() keeps the D step from backpropagating into G.
            output = netD(fake_image_label.detach())
            errD_fake = criterion(output, fake_target)
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            ############################
            # Update generator G
            ###########################
            netG.zero_grad()  # reset gradients

            # Re-score the fake pairs with the freshly updated D; G wants
            # them judged real, so the target is "1".
            output = netD(fake_image_label)
            errG = criterion(output, real_target)
            errG.backward()
            D_G_z2 = output.mean().item()

            optimizerG.step()

            print(
                '[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                .format(epoch + 1, opt.n_epoch, itr + 1, len(dataloader),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if epoch == 0 and itr == 0:  # save a grid of real samples once
                vutils.save_image(real_image,
                                  '{}/real_samples.png'.format(opt.outf),
                                  normalize=True,
                                  nrow=10)

        ############################
        # Per-epoch preview images
        ############################
        # Generate fakes for the fixed labels at the end of every epoch.
        fake_image = netG(fixed_noise_label)
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(
                              opt.outf, epoch + 1),
                          normalize=True,
                          nrow=10)

        ############################
        # Checkpointing
        ############################
        if (epoch + 1) % 50 == 0:  # save both models every 50 epochs
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))