Exemplo n.º 1
0
def main():
    """Parse command-line options and run VideoGAN training with Chainer.

    The function builds a generator/discriminator pair, one Adam optimizer
    (with a weight-decay hook) per model, a ``VideoDataset`` wrapped in a
    ``SerialIterator`` and a ``Trainer`` carrying snapshot, logging,
    progress-bar and video-preview extensions.  Once ``trainer.run()``
    returns, the generator is serialized to ``<out>/gen.npz``.
    """
    cli = argparse.ArgumentParser(description='VideoGAN')
    cli.add_argument(
        '--batchsize', '-b', type=int, default=64,
        help='Number of videos in each mini-batch')
    cli.add_argument(
        '--epoch', '-e', type=int, default=1000,
        help='Number of sweeps over the dataset to train')
    cli.add_argument(
        '--gpu', '-g', type=int, default=-1,
        help='GPU ID (negative value indicates CPU)')
    cli.add_argument(
        '--dataset', '-i', required=True,
        help='Directory list of images files')
    cli.add_argument(
        '--root', default='.',
        help='Path of images files')
    cli.add_argument(
        '--out', '-o', default='result',
        help='Directory to output the result')
    cli.add_argument(
        '--resume', '-r', default='',
        help='Resume the training from snapshot')
    cli.add_argument(
        '--seed', type=int, default=0,
        help='Random seed of z at visualization stage')
    cli.add_argument(
        '--snapshot_interval', type=int, default=1000,
        help='Interval of snapshot')
    cli.add_argument(
        '--display_interval', type=int, default=100,
        help='Interval of displaying log to console')
    cli.add_argument(
        '--video_codecs', default='H264',
        help='Video Codec')
    cli.add_argument(
        '--video_ext', default='avi',
        help='Extension of output video files')
    args = cli.parse_args()

    # A non-negative GPU id means "train on that GPU".
    use_gpu = args.gpu >= 0

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    # Networks to train.
    gen = Generator()
    dis = Discriminator()
    if use_gpu:
        chainer.cuda.get_device_from_id(args.gpu).use()
        for net in (gen, dis):
            net.to_gpu()

    def build_adam(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer bound to *model* with weight decay."""
        adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        adam.setup(model)
        adam.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        return adam

    optimizers = {'gen': build_adam(gen), 'dis': build_adam(dis)}

    # Dataset and a short summary of it.
    video_paths = _get_images_paths(args.dataset, args.root)
    train = VideoDataset(paths=video_paths)
    print('# data-size: {}'.format(len(train)))
    print('# data-shape: {}'.format(train[0].shape))
    print('')

    updater = VideoGANUpdater(
        models=(gen, dis),
        iterator=chainer.iterators.SerialIterator(train, args.batchsize),
        optimizer=optimizers,
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Triggers, both counted in iterations.
    every_snapshot = (args.snapshot_interval, 'iteration')
    every_display = (args.display_interval, 'iteration')

    # Full trainer snapshot first, then per-model snapshots (gen, dis).
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=every_snapshot)
    for net, pattern in ((gen, 'gen_iter_{.updater.iteration}.npz'),
                         (dis, 'dis_iter_{.updater.iteration}.npz')):
        trainer.extend(
            extensions.snapshot_object(net, pattern),
            trigger=every_snapshot)

    # Console reporting.
    reported_keys = ['epoch', 'iteration', 'gen/loss', 'dis/loss']
    trainer.extend(extensions.LogReport(trigger=every_display))
    trainer.extend(extensions.PrintReport(reported_keys),
                   trigger=every_display)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Preview videos are written on the snapshot schedule.
    preview = out_generated_video(
        gen, dis, 5, args.seed, args.out,
        args.video_codecs, args.video_ext)
    trainer.extend(preview, trigger=every_snapshot)

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()

    # Move the generator back to the CPU before the final save.
    if use_gpu:
        gen.to_cpu()
    chainer.serializers.save_npz(os.path.join(args.out, 'gen.npz'), gen)
Exemplo n.º 2
0
def train(folding_id, inliner_classes, ic):
    """Train generator, discriminator, encoder and latent discriminator.

    The configuration is always the project defaults merged with
    ``configs/mnist.yaml``.  Every batch runs four update phases in order:
    image discriminator ``D``, generator ``G``, latent discriminator ``ZD``,
    and finally encoder ``E`` together with ``G`` (reconstruction + KL +
    adversarial latent loss).  After each epoch a reconstruction grid and a
    grid generated from a fixed noise sample are written to the output
    folder; after training the ``G`` and ``E`` state dicts are saved under
    ``models/``.

    Args:
        folding_id: Fold identifier; used in the output folder name, passed
            to ``make_datasets`` and formatted with ``%d`` in the model
            file names.
        inliner_classes: Iterable of class identifiers; joined with ``_``
            into the output folder name and passed to ``make_datasets``.
        ic: Value formatted with ``%d`` into the saved model file names.
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file('configs/mnist.yaml')
    cfg.freeze()
    logger = logging.getLogger("logger")

    zsize = cfg.MODEL.LATENT_SIZE
    #print("zsize: "+str(zsize))
    # Results go to "results_<folding_id>_<class>_<class>..." in the cwd.
    output_folder = os.path.join('results_' + str(folding_id) + "_" + "_".join([str(x) for x in inliner_classes]))
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs('models', exist_ok=True)

    # Only the first of the three datasets returned is used here.
    train_set, _, _ = make_datasets(cfg, folding_id, inliner_classes)

    logger.info("Train set size: %d" % len(train_set))

    G = Generator(cfg.MODEL.LATENT_SIZE, channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    G.weight_init(mean=0, std=0.02)

    D = Discriminator(channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    D.weight_init(mean=0, std=0.02)

    #LZ=LatentZ(1,1)

    E = Encoder(cfg.MODEL.LATENT_SIZE, channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    E.weight_init(mean=0, std=0.02)

    # The latent discriminator variant is chosen by configuration.
    if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH:
        ZD = ZDiscriminator_mergebatch(zsize, cfg.TRAIN.BATCH_SIZE)
    else:
        ZD = ZDiscriminator(zsize, cfg.TRAIN.BATCH_SIZE)
    ZD.weight_init(mean=0, std=0.02)

    lr = cfg.TRAIN.BASE_LEARNING_RATE
    #lr=0.0001


    # G is optimized twice per batch: alone (G_optimizer) and jointly with
    # the encoder (GE_optimizer).
    G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    GE_optimizer = optim.Adam(list(E.parameters()) + list(G.parameters()), lr=lr, betas=(0.5, 0.999))
    ZD_optimizer = optim.Adam(ZD.parameters(), lr=lr, betas=(0.5, 0.999))

    BCE_loss = nn.BCELoss()
    # Fixed noise reused every epoch for the 'sample_<epoch>.png' grid.
    sample = torch.randn(64, zsize).view(-1, zsize, 1, 1)

    tracker = LossTracker(output_folder=output_folder)

    #cfg.TRAIN.EPOCH_COUNT
    # NOTE(review): the epoch count is hard-coded to 100 here (and in the
    # log message below) instead of reading cfg.TRAIN.EPOCH_COUNT -- confirm
    # this is intentional.
    for epoch in range(100):
        G.train()
        D.train()
        E.train()
        ZD.train()

        epoch_start_time = time.time()
        print(" cfg.TRAIN.BATCH_SIZE :"+str( cfg.TRAIN.BATCH_SIZE))
        data_loader = make_dataloader(train_set, cfg.TRAIN.BATCH_SIZE, torch.cuda.current_device())
        train_set.shuffle()

        #if (epoch + 1) % 30 == 0:
            #G_optimizer.param_groups[0]['lr'] /= 4
            #D_optimizer.param_groups[0]['lr'] /= 4
            #GE_optimizer.param_groups[0]['lr'] /= 4
            #ZD_optimizer.param_groups[0]['lr'] /= 4
            #print("learning rate change!")

        # The first element of each batch (y) is not used in this loop.
        for y, x in data_loader:
            x = x.view(-1, cfg.MODEL.INPUT_IMAGE_CHANNELS, cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)
            
            #print("x: "+str(x.size()))

            # Targets for the image discriminator: one entry per sample.
            y_real_ = torch.ones(x.shape[0])
            y_fake_ = torch.zeros(x.shape[0])
            
            #print(" y r: "+str(y_real_.size()))
            #print("y f: "+str(y_fake_.size()))

            # Targets for the latent discriminator: a single entry in
            # cross-batch mode, otherwise one per sample.
            y_real_z = torch.ones(1 if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH else x.shape[0])
            y_fake_z = torch.zeros(1 if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH else x.shape[0])
            
            #print("y real z: "+str(y_real_z.size()))
            #print("y fake z: "+str(y_fake_z.size()))

            #############################################
            # Phase 1: update D on real images and detached generated images.

            D.zero_grad()

            D_result = D(x).squeeze()
            D_real_loss = BCE_loss(D_result, y_real_)

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize, 1, 1)
            z = Variable(z)

            # detach() keeps this phase from back-propagating into G.
            x_fake = G(z).detach()
            D_result = D(x_fake).squeeze()
            D_fake_loss = BCE_loss(D_result, y_fake_)

            D_train_loss = D_real_loss + D_fake_loss
            D_train_loss.backward()

            D_optimizer.step()

            tracker.update(dict(D=D_train_loss))


            #############################################
            # Phase 2: update G so that D scores fresh fakes as real.

            G.zero_grad()

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize, 1, 1)
            z = Variable(z)

            x_fake = G(z)
            D_result = D(x_fake).squeeze()

            G_train_loss = BCE_loss(D_result, y_real_)

            G_train_loss.backward()
            G_optimizer.step()

            tracker.update(dict(G=G_train_loss))

            #############################################
            # Phase 3: update ZD; random normal z is labelled real, the
            # detached encoder output is labelled fake.

            ZD.zero_grad()

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize)
            z = Variable(z)

            ZD_result = ZD(z).squeeze()
            ZD_real_loss = BCE_loss(ZD_result, y_real_z)

            # E returns a triple, unpacked here as (z, logvar, mu).
            z,logvar,mu = E(x)
            z=z.squeeze().detach()

            ZD_result = ZD(z).squeeze()
            ZD_fake_loss = BCE_loss(ZD_result, y_fake_z)

            ZD_train_loss = ZD_real_loss + ZD_fake_loss
            ZD_train_loss.backward()

            ZD_optimizer.step()

            tracker.update(dict(ZD=ZD_train_loss))

            # #############################################
            # Phase 4: update E and G jointly on reconstruction + KL, plus
            # the adversarial loss of the encoded z against ZD.

            E.zero_grad()
            G.zero_grad()

            #z = E(x)
            #x_d = G(z)

            z,logvar,mu = E(x)
            x_d = G(z)
            
            #print("------------------------ here it goes")
            #print(str(x.size()))
            #print(str(p_x.size()))
            #print(str(z.size()))
            #print(str(x_d.size()))
            #print("------------------------ here it goes")

            ZD_result = ZD(z.squeeze()).squeeze()

            E_train_loss = BCE_loss(ZD_result, y_real_z) * 1.0

            # Summed (not averaged) reconstruction BCE, scaled by 2.
            Recon_loss = F.binary_cross_entropy(x_d, x.detach(),reduction="sum")*2.0
            # KL term computed from mu and logvar, summed over all elements.
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            
            GE_loss=Recon_loss+kl_loss

            (GE_loss+E_train_loss).backward()

            GE_optimizer.step()

            tracker.update(dict(GE=GE_loss,RE=Recon_loss,K=kl_loss, E=E_train_loss))

            # #############################################

        # x and x_d are the tensors left over from the last batch of the
        # epoch.  NOTE(review): they are undefined if the loader yields no
        # batch at all -- confirm that cannot happen.
        comparison = torch.cat([x, x_d])
        save_image(comparison.cpu(), os.path.join(output_folder, 'reconstruction_' + str(epoch) + '.png'), nrow=x.shape[0])

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        logger.info('[%d/%d] - ptime: %.2f, %s' % ((epoch + 1), 100, per_epoch_ptime, tracker))

        tracker.register_means(epoch)
        tracker.plot()

        # Render the fixed noise sample without tracking gradients.
        with torch.no_grad():
            resultsample = G(sample).cpu()
            save_image(resultsample.view(64,
                                         cfg.MODEL.INPUT_IMAGE_CHANNELS,
                                         cfg.MODEL.INPUT_IMAGE_SIZE,
                                         cfg.MODEL.INPUT_IMAGE_SIZE),
                       os.path.join(output_folder, 'sample_' + str(epoch) + '.png'))

    logger.info("Training finish!... save training results")

    os.makedirs("models", exist_ok=True)

    print("Training finish!... save training results")
    # Only the generator and the encoder are persisted.
    torch.save(G.state_dict(), "models/Gmodel_%d_%d.pkl" %(folding_id, ic))
    torch.save(E.state_dict(), "models/Emodel_%d_%d.pkl" %(folding_id, ic))
def main():
    """Train an LSGAN on STL-10 and periodically save samples and weights.

    Command-line options control the number of loader workers, the batch
    size, the latent size, the channel widths of both networks, the epoch
    count, the Adam hyper-parameters and the output folder.  Both networks
    are trained with a mean-squared-error criterion against targets of 1
    (real) and 0 (fake).

    Side effects: requests a download of STL-10 into
    ``./dataset/stl10_root``, creates ``opt.outf``, prints the losses for
    every iteration, writes one grid of generated images per epoch and
    saves both state dicts every 50 epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers', type=int,
                        help='number of data loading workers', default=2)
    parser.add_argument('--batch_size', type=int,
                        default=50, help='input batch size')
    parser.add_argument('--nz', type=int, default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_g', type=int, default=64)
    parser.add_argument('--nch_d', type=int, default=64)
    parser.add_argument('--n_epoch', type=int, default=200,
                        help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', default='./result_lsgan',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    # exist_ok tolerates a pre-existing folder while still surfacing real
    # failures (e.g. permission errors) here instead of at the first save.
    os.makedirs(opt.outf, exist_ok=True)

    # Fix the random seeds.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Both STL-10 splits go through the same augmentation pipeline.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(
            64, scale=(88 / 96, 1.0), ratio=(1., 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Labels are not used, so the split that also contains the unlabeled
    # images ('train+unlabeled') is loaded.
    trainset = dset.STL10(root='./dataset/stl10_root', download=True,
                          split='train+unlabeled', transform=transform)
    testset = dset.STL10(root='./dataset/stl10_root', download=True,
                         split='test', transform=transform)
    # The train and test splits are merged into one training set.
    dataset = trainset + testset

    # Data loader over the merged training data.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                             shuffle=True, num_workers=int(opt.workers))

    # Pick the training device; use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device:', device)
    # An index-less "cuda" device resolves to the current CUDA device, so
    # selecting GPU 1 places everything below on that GPU.  Only do so when
    # GPU 1 exists; the call fails on CPU-only and single-GPU machines.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)

    # Generator G: produces fake images from random vectors.
    netG = Generator(nz=opt.nz, nch_g=opt.nch_g).to(device)
    netG.apply(weights_init)    # initialize with weights_init
    print(netG)

    # Discriminator D: judges whether an image is real or fake.
    netD = Discriminator(nch_d=opt.nch_d).to(device)
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()    # mean squared error loss

    # Optimizer setup.
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(
        opt.beta1, 0.999), weight_decay=1e-5)  # for discriminator D
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(
        opt.beta1, 0.999), weight_decay=1e-5)  # for generator G

    # Fixed noise used to render comparable preview images every epoch.
    fixed_noise = torch.randn(opt.batch_size, opt.nz,
                              1, 1, device=device)

    # Training loop.
    for epoch in range(opt.n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)     # real images
            sample_size = real_image.size(0)    # number of images in the batch
            # Noise drawn from a normal distribution.
            noise = torch.randn(sample_size, opt.nz, 1, 1,
                                device=device)

            # Target value 1 for real images.
            real_target = torch.full(
                (sample_size,), 1., device=device)
            # Target value 0 for fake images.
            fake_target = torch.full((sample_size,), 0., device=device)

            ############################
            # Update discriminator D
            ###########################
            netD.zero_grad()    # reset gradients

            output = netD(real_image)   # D's output for the real images
            errD_real = criterion(output, real_target)  # loss on real images
            D_x = output.mean().item()

            fake_image = netG(noise)    # G generates fake images from noise

            # D's output for the fake images; detach() keeps this pass from
            # back-propagating into G.
            output = netD(fake_image.detach())
            errD_fake = criterion(output, fake_target)  # loss on fake images
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake    # total discriminator loss
            errD.backward()    # back-propagation
            optimizerD.step()   # update D's parameters

            ############################
            # Update generator G
            ###########################
            netG.zero_grad()    # reset gradients

            # Score the fake images again with the freshly updated D.
            output = netD(fake_image)
            # G wants D to take fakes for real images, so the target is 1.
            errG = criterion(output, real_target)
            errG.backward()     # back-propagation
            D_G_z2 = output.mean().item()

            optimizerG.step()   # update G's parameters

            print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                  .format(epoch + 1, opt.n_epoch,
                          itr + 1, len(dataloader),
                          errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if epoch == 0 and itr == 0:     # save the real images once
                vutils.save_image(real_image, '{}/real_samples.png'.format(opt.outf),
                                  normalize=True, nrow=10)

        ############################
        # Generate preview images
        ############################
        # One preview grid per epoch; no gradients are needed for it.
        with torch.no_grad():
            fake_image = netG(fixed_noise)
        vutils.save_image(fake_image.detach(), '{}/fake_samples_epoch_{:03d}.png'.format(opt.outf, epoch + 1),
                          normalize=True, nrow=10)

        ############################
        # Save the models
        ############################
        if (epoch + 1) % 50 == 0:   # save every 50 epochs
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
Exemplo n.º 4
0
# device = torch.device("cuda:0,1" if opt.cuda else "cpu")
# Number of GPUs taken from the parsed options (not used in this fragment).
ngpu = int(opt.nGPU)
# ------------------------------------ step 1/5 ------------------------------------
# Data: both pipelines only convert the input to a tensor.

trainTransform = transforms.Compose([transforms.ToTensor()])

# NOTE(review): validTransform is defined but not used in the visible code.
validTransform = transforms.Compose([transforms.ToTensor()])

# The test dataset is built with trainTransform (identical to validTransform).
test_data = MyTestDataset(path=opt.test_path,
                          pkfile=opt.testfile,
                          transform=trainTransform)
# One sample per batch, in dataset order.
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

# ------------------------------------ step 2/5 ------------------------------------
# Model: a Generator constructed with 12 as its only argument.
n_res_blk = 12
net = Generator(n_res_blk)
print(net)


def _load_ckpt(model, ckpt):
    """Load *ckpt* into *model*, dropping a leading ``module.`` from keys.

    Keys that start with ``module.`` have exactly that leading prefix
    removed; all other keys are passed through unchanged instead of raising
    ``IndexError``.  Loading is non-strict, so missing or unexpected keys do
    not raise.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        ckpt: Mapping from parameter name to value.

    Returns:
        The same *model* object, after loading.
    """
    prefix = 'module.'
    load_dict = {}
    for name in ckpt:
        # Strip only a leading prefix; an inner "module." is part of the name.
        key = name[len(prefix):] if name.startswith(prefix) else name
        load_dict[key] = ckpt[name]
    model.load_state_dict(load_dict, strict=False)
    return model


# Only touch the checkpoint when a path was supplied in the options.
if opt.load_ckpt is not None:
    # load params
    # NOTE(review): torch.load unpickles the file -- only use trusted
    # checkpoints.
    pretrained_dict = torch.load(opt.load_ckpt)
    # Current parameters of the freshly built network; how the two dicts
    # are combined is not visible in this fragment.
    net_state_dict = net.state_dict()
Exemplo n.º 5
0
def main():
    """Parse command-line options and run DCGAN training with Chainer.

    Without ``--dataset`` the CIFAR-10 images are used; otherwise every
    file in the given directory whose name ends in ``.png`` or ``.jpg``
    (case-insensitive) is loaded through ``ImageDataset``.  The trainer
    writes snapshots, logs and preview images on the configured iteration
    intervals and can resume from a snapshot given with ``--resume``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: DCGAN')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=50,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--dataset',
                        '-i',
                        default='',
                        help='Directory of image files.  Default is cifar-10.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--n_hidden',
                        '-n',
                        type=int,
                        default=100,
                        help='Number of hidden units (z)')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# n_hidden: {}'.format(args.n_hidden))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    gen = Generator(n_hidden=args.n_hidden)
    dis = Discriminator()

    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()  # Copy the model to the GPU
        dis.to_gpu()

    # Setup an optimizer
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer bound to *model* with weight decay."""
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(0.0001),
                           'hook_dec')
        return optimizer

    opt_gen = make_optimizer(gen)
    opt_dis = make_optimizer(dis)

    if args.dataset == '':
        # Load the CIFAR10 dataset if args.dataset is not specified
        train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=255.)
    else:
        # Select images by file extension.  A plain substring test would
        # also pick up names such as "png_list.txt" or sub-directories that
        # merely contain "png"/"jpg", which cannot be read as images.
        image_files = [
            f for f in os.listdir(args.dataset)
            if f.lower().endswith(('.png', '.jpg'))
        ]
        print('{} contains {} image files'.format(args.dataset,
                                                  len(image_files)))
        train = chainer.datasets\
            .ImageDataset(paths=image_files, root=args.dataset)

    # Setup an iterator
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Setup an updater
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer={
                               'gen': opt_gen,
                               'dis': opt_dis
                           },
                           device=args.gpu)

    # Setup a trainer
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Both intervals are counted in iterations.
    snapshot_interval = (args.snapshot_interval, 'iteration')
    display_interval = (args.display_interval, 'iteration')
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        'epoch',
        'iteration',
        'gen/loss',
        'dis/loss',
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    # Preview images are written on the snapshot schedule.
    trainer.extend(out_generated_image(gen, dis, 10, 10, args.seed, args.out),
                   trigger=snapshot_interval)

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Exemplo n.º 6
0
def train(folding_id, inliner_classes, ic, cfg):
    """Train generator, discriminator, encoder and latent discriminator.

    Every batch runs four update phases in order: image discriminator
    ``D``, generator ``G``, latent discriminator ``ZD``, and finally encoder
    ``E`` together with ``G`` (reconstruction + adversarial latent loss).
    Every 30th epoch the learning rate of the first parameter group of all
    four optimizers is divided by 4.  After each epoch a reconstruction
    grid and a grid generated from a fixed noise sample are written to the
    output folder; after training the ``G`` and ``E`` state dicts are saved
    under ``<cfg.OUTPUT_FOLDER>/models``.

    Args:
        folding_id: Fold identifier; used in the output folder name, passed
            to ``make_datasets`` and formatted with ``%d`` in the model
            file names.
        inliner_classes: Iterable of class identifiers; joined with ``_``
            into the output folder name and passed to ``make_datasets``.
        ic: Value formatted with ``%d`` into the saved model file names.
        cfg: Configuration object; this function reads ``OUTPUT_FOLDER``,
            ``MODEL.*`` and ``TRAIN.*`` entries from it.
    """
    logger = logging.getLogger("logger")

    zsize = cfg.MODEL.LATENT_SIZE
    # Results go to "<OUTPUT_FOLDER>/results_<folding_id>_<class>_<class>...".
    output_folder = os.path.join('results_' + str(folding_id) + "_" + "_".join([str(x) for x in inliner_classes]))
    output_folder = os.path.join(cfg.OUTPUT_FOLDER, output_folder)

    os.makedirs(output_folder, exist_ok=True)
    # Checkpoint folder used by the torch.save calls at the end.
    os.makedirs(os.path.join(cfg.OUTPUT_FOLDER, 'models'), exist_ok=True)

    # Only the first of the three datasets returned is used here.
    train_set, _, _ = make_datasets(cfg, folding_id, inliner_classes)

    logger.info("Train set size: %d" % len(train_set))

    G = Generator(cfg.MODEL.LATENT_SIZE, channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    G.weight_init(mean=0, std=0.02)

    D = Discriminator(channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    D.weight_init(mean=0, std=0.02)

    E = Encoder(cfg.MODEL.LATENT_SIZE, channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    E.weight_init(mean=0, std=0.02)

    # The latent discriminator variant is chosen by configuration.
    if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH:
        ZD = ZDiscriminator_mergebatch(zsize, cfg.TRAIN.BATCH_SIZE)
    else:
        ZD = ZDiscriminator(zsize, cfg.TRAIN.BATCH_SIZE)
    ZD.weight_init(mean=0, std=0.02)

    lr = cfg.TRAIN.BASE_LEARNING_RATE

    # G is optimized twice per batch: alone (G_optimizer) and jointly with
    # the encoder (GE_optimizer).
    G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    GE_optimizer = optim.Adam(list(E.parameters()) + list(G.parameters()), lr=lr, betas=(0.5, 0.999))
    ZD_optimizer = optim.Adam(ZD.parameters(), lr=lr, betas=(0.5, 0.999))

    BCE_loss = nn.BCELoss()
    # Fixed noise reused every epoch for the 'sample_<epoch>.png' grid.
    sample = torch.randn(64, zsize).view(-1, zsize, 1, 1)

    tracker = LossTracker(output_folder=output_folder)

    for epoch in range(cfg.TRAIN.EPOCH_COUNT):
        G.train()
        D.train()
        E.train()
        ZD.train()

        epoch_start_time = time.time()

        data_loader = make_dataloader(train_set, cfg.TRAIN.BATCH_SIZE, torch.cuda.current_device())
        train_set.shuffle()

        # Step decay: divide the learning rate by 4 every 30th epoch.
        if (epoch + 1) % 30 == 0:
            G_optimizer.param_groups[0]['lr'] /= 4
            D_optimizer.param_groups[0]['lr'] /= 4
            GE_optimizer.param_groups[0]['lr'] /= 4
            ZD_optimizer.param_groups[0]['lr'] /= 4
            print("learning rate change!")

        # The first element of each batch (y) is not used in this loop.
        for y, x in data_loader:
            x = x.view(-1, cfg.MODEL.INPUT_IMAGE_CHANNELS, cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)

            # Targets for the image discriminator: one entry per sample.
            y_real_ = torch.ones(x.shape[0])
            y_fake_ = torch.zeros(x.shape[0])

            # Targets for the latent discriminator: a single entry in
            # cross-batch mode, otherwise one per sample.
            y_real_z = torch.ones(1 if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH else x.shape[0])
            y_fake_z = torch.zeros(1 if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH else x.shape[0])

            #############################################
            # Phase 1: update D on real images and detached generated images.

            D.zero_grad()

            D_result = D(x).squeeze()
            D_real_loss = BCE_loss(D_result, y_real_)

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize, 1, 1)
            z = Variable(z)

            # detach() keeps this phase from back-propagating into G.
            x_fake = G(z).detach()
            D_result = D(x_fake).squeeze()
            D_fake_loss = BCE_loss(D_result, y_fake_)

            D_train_loss = D_real_loss + D_fake_loss
            D_train_loss.backward()

            D_optimizer.step()

            tracker.update(dict(D=D_train_loss))


            #############################################
            # Phase 2: update G so that D scores fresh fakes as real.

            G.zero_grad()

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize, 1, 1)
            z = Variable(z)

            x_fake = G(z)
            D_result = D(x_fake).squeeze()

            G_train_loss = BCE_loss(D_result, y_real_)

            G_train_loss.backward()
            G_optimizer.step()

            tracker.update(dict(G=G_train_loss))

            #############################################
            # Phase 3: update ZD; random normal z is labelled real, the
            # detached encoder output is labelled fake.

            ZD.zero_grad()

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize)
            z = z.requires_grad_(True)

            ZD_result = ZD(z).squeeze()
            ZD_real_loss = BCE_loss(ZD_result, y_real_z)

            z = E(x).squeeze().detach()

            ZD_result = ZD(z).squeeze()
            ZD_fake_loss = BCE_loss(ZD_result, y_fake_z)

            ZD_train_loss = ZD_real_loss + ZD_fake_loss
            ZD_train_loss.backward()

            ZD_optimizer.step()

            tracker.update(dict(ZD=ZD_train_loss))

            # #############################################
            # Phase 4: update E and G jointly on the reconstruction loss
            # plus the adversarial loss of the encoded z against ZD.

            E.zero_grad()
            G.zero_grad()

            z = E(x)
            x_d = G(z)

            ZD_result = ZD(z.squeeze()).squeeze()

            E_train_loss = BCE_loss(ZD_result, y_real_z) * 1.0

            Recon_loss = F.binary_cross_entropy(x_d, x.detach()) * 2.0

            (Recon_loss + E_train_loss).backward()

            GE_optimizer.step()

            tracker.update(dict(GE=Recon_loss, E=E_train_loss))

            # #############################################

        # x and x_d are the tensors left over from the last batch of the
        # epoch.  NOTE(review): they are undefined if the loader yields no
        # batch at all -- confirm that cannot happen.
        comparison = torch.cat([x, x_d])
        save_image(comparison.cpu(), os.path.join(output_folder, 'reconstruction_' + str(epoch) + '.png'), nrow=x.shape[0])

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        logger.info('[%d/%d] - ptime: %.2f, %s' % ((epoch + 1), cfg.TRAIN.EPOCH_COUNT, per_epoch_ptime, tracker))

        tracker.register_means(epoch)
        tracker.plot()

        # Render the fixed noise sample without tracking gradients.
        with torch.no_grad():
            resultsample = G(sample).cpu()
            save_image(resultsample.view(64,
                                         cfg.MODEL.INPUT_IMAGE_CHANNELS,
                                         cfg.MODEL.INPUT_IMAGE_SIZE,
                                         cfg.MODEL.INPUT_IMAGE_SIZE),
                       os.path.join(output_folder, 'sample_' + str(epoch) + '.png'))

    logger.info("Training finish!... save training results")

    # The checkpoint folder under cfg.OUTPUT_FOLDER was created before
    # training.  No "models" directory is created in the current working
    # directory: nothing is written there, and a failure to create it would
    # abort the run right before the weights are saved.

    print("Training finish!... save training results")
    # Only the generator and the encoder are persisted.
    torch.save(G.state_dict(), os.path.join(cfg.OUTPUT_FOLDER, "models/Gmodel_%d_%d.pkl" %(folding_id, ic)))
    torch.save(E.state_dict(), os.path.join(cfg.OUTPUT_FOLDER, "models/Emodel_%d_%d.pkl" %(folding_id, ic)))
Exemplo n.º 7
0
# Load the custom dataset
dataset = CustomDataset(root, data_transform)

# Create a data loader over the training data
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=int(workers))

# Pick the device used for training; use the GPU when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)

# Generator G: produces fake images from random vectors
netG = Generator(nz=nz, nch_g=nch_g).to(device)
netG.apply(weights_init)  # initialize with the weights_init function
print(netG)

# Discriminator D: judges whether an image is real or fake
netD = Discriminator(nch_d=nch_d).to(device)
netD.apply(weights_init)
print(netD)

criterion = nn.MSELoss()  # the loss function is mean squared error

# Set up the optimizers
optimizerD = optim.Adam(netD.parameters(),
                        lr=lr,
                        betas=(beta1, 0.999),
                        weight_decay=1e-5)  # for discriminator D
Exemplo n.º 8
0
Arquivo: train.py Projeto: cdefga/gan
def train(batch_size=1, latent_size=100, learning_rate=2e-3, num_epochs=100):
    """Train the generator/discriminator pair and dump per-epoch samples.

    Each batch performs one discriminator update followed by one generator
    update. Losses are sent to Weights & Biases every 100 batches and once
    per epoch. After every epoch the generator is run on a fixed latent batch
    and the results are pickled to ``gen_imgs.pkl`` when training finishes.

    Args:
        batch_size: Batch size handed to ``datamaker``; also the number of
            fixed latent vectors used for the per-epoch samples.
        latent_size: Dimensionality of the generator's latent input.
        learning_rate: Learning rate of both Adam optimizers.
        num_epochs: Number of passes over the data loader.
    """
    cuda = torch.cuda.is_available()

    def sample_latent(count):
        # Uniform noise in [-1, 1) drawn through NumPy so the NumPy RNG stream
        # is consumed exactly as before; moved to the GPU when one is in use.
        noise = np.random.uniform(-1, 1, size=(count, latent_size))
        noise = torch.from_numpy(noise).float()
        return noise.cuda() if cuda else noise

    dataloader = datamaker(batch_size=batch_size)
    fixed_img = sample_latent(batch_size)
    gen_imgs = []

    G = Generator(input_size=latent_size)
    D = Discriminator()
    if cuda:
        print('Using CUDA')
        G.cuda()
        D.cuda()

    g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
    d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)

    wandb.watch(G)
    wandb.watch(D)
    for epoch in range(num_epochs):
        D.train()
        G.train()

        # Defined up front so the epoch-level log below cannot hit an unbound
        # name when the loader yields no batches.
        g_loss_value = 0.0
        d_loss_value = 0.0

        for idx, (real_images, _) in enumerate(tqdm(dataloader)):
            if cuda:
                real_images = real_images.cuda()

            # The last batch may be smaller than the configured batch size.
            current_batch_size = real_images.size(0)
            # NOTE(review): presumably rescales [0, 1] inputs to [-1, 1] --
            # confirm against the transforms used by ``datamaker``.
            real_images = real_images * 2 - 1

            # ---- discriminator update ----
            # Detach the fakes: this step must not backpropagate into G.
            fake_images = G(sample_latent(current_batch_size)).detach()

            d_optimizer.zero_grad()
            d_real_loss = real_loss(D(real_images), smooth=True)
            d_fake_loss = fake_loss(D(fake_images))
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            d_optimizer.step()
            d_loss_value = d_loss.item()

            # ---- generator update ----
            # Fresh fakes are scored against the "real" target so that G is
            # pushed towards fooling the just-updated discriminator.
            g_optimizer.zero_grad()
            fake_images = G(sample_latent(current_batch_size))
            g_loss = real_loss(D(fake_images))
            g_loss.backward()
            g_optimizer.step()
            g_loss_value = g_loss.item()

            if idx % 100 == 0:
                wandb.log({'G Loss': g_loss_value, 'D Loss': d_loss_value})

        # wandb requires monotonically increasing steps; passing step=epoch
        # here would fall behind the implicit step advanced by the per-batch
        # logs and the entry would be dropped. Log the epoch as a value.
        wandb.log({
            'G Epoch Loss': g_loss_value,
            'D Epoch Loss': d_loss_value,
            'epoch': epoch,
        })

        # test performance: no autograd graph is needed (or kept alive) for
        # the stored samples.
        G.eval()
        with torch.no_grad():
            gen_imgs.append(G(fixed_img))

    # dump generated images
    with open('gen_imgs.pkl', 'wb') as f:
        pkl.dump(gen_imgs, f)
Exemplo n.º 9
0
def main():
    """Train a GAN with a least-squares (MSE) adversarial loss.

    Parses the command line, builds the dataset selected by ``--dataset``,
    creates the generator and discriminator (optionally restoring them from
    ``--gen`` / ``--dis``) and alternates discriminator and generator updates.
    After every epoch a grid of samples drawn from a fixed latent batch and
    both state dicts are written to ``--outf``.

    Raises:
        ValueError: If the selected dataset contains no samples.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset',
                        required=True,
                        choices=['cifar10', 'lsun', 'imagenet', 'folder',
                                 'lfw', 'fake'],
                        help='cifar10 | lsun | imagenet | folder | lfw | fake')
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batchSize',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument(
        '--imageSize',
        type=int,
        default=64,
        help='the height / width of the input image to network')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_gen', type=int, default=512)
    parser.add_argument('--nch_dis', type=int, default=512)
    parser.add_argument('--nepoch',
                        type=int,
                        default=1000,
                        help='number of epochs to train for')
    # The help texts below now state the defaults that are actually in effect.
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, default=0.001')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='beta1 for adam. default=0.9')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--ngpu',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    parser.add_argument('--gen',
                        default='',
                        help="path to gen (to continue training)")
    parser.add_argument('--dis',
                        default='',
                        help="path to dis (to continue training)")
    parser.add_argument('--outf',
                        default='./result',
                        help='folder to output images and model checkpoints')
    parser.add_argument('--manualSeed', type=int, help='manual seed')

    args = parser.parse_args()
    print(args)

    # An existing folder is fine; any other failure (e.g. permissions) should
    # surface here instead of at the first save call.
    os.makedirs(args.outf, exist_ok=True)

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", args.manualSeed)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # [0, +1] -> [-1, +1]
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    if args.dataset in ('imagenet', 'folder', 'lfw'):
        # folder dataset
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    elif args.dataset == 'lsun':
        dataset = dset.LSUN(root=args.dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                normalize,
                            ]))
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(args.imageSize),
                                   transforms.ToTensor(),
                                   normalize,
                               ]))
    else:
        # 'fake' -- argparse ``choices`` guarantees no other value gets here.
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())

    if len(dataset) == 0:
        raise ValueError(
            'dataset {!r} at {!r} contains no samples'.format(
                args.dataset, args.dataroot))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=int(args.workers))

    device = torch.device("cuda:0" if args.cuda else "cpu")
    nch_img = 3

    # custom weights initialization called on gen and dis
    def weights_init(m):
        classname = m.__class__.__name__
        if 'Conv' in classname:
            m.weight.data.normal_(0.0, 0.02)
            # Convolutions built with bias=False have no bias tensor.
            if m.bias is not None:
                m.bias.data.normal_(0.0, 0.02)
        elif 'BatchNorm' in classname:
            m.weight.data.normal_(1.0, 0.02)

    gen = Generator(args.ngpu, args.nz, args.nch_gen, nch_img).to(device)
    gen.apply(weights_init)
    if args.gen != '':
        # map_location lets a GPU checkpoint be resumed on a CPU-only run.
        gen.load_state_dict(torch.load(args.gen, map_location=device))

    dis = Discriminator(args.ngpu, args.nch_dis, nch_img).to(device)
    dis.apply(weights_init)
    if args.dis != '':
        dis.load_state_dict(torch.load(args.dis, map_location=device))

    # Least-squares GAN objective: MSE between D's output and a target label.
    criterion = nn.MSELoss()

    # Fixed latent batch (8 x 8 grid) used for the per-epoch sample image.
    fixed_z = torch.randn(8 * 8, args.nz, 1, 1, device=device)
    a_label = 0  # target D is trained to output for fake images
    b_label = 1  # target D is trained to output for real images
    c_label = 1  # target G wants D to output for its fake images

    # setup optimizer
    optim_dis = optim.Adam(dis.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))
    optim_gen = optim.Adam(gen.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))

    for epoch in range(args.nepoch):
        for itr, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: push D(x) towards b and D(G(z)) towards a
            ###########################
            # train with real
            dis.zero_grad()
            real_img = data[0].to(device)
            batch_size = real_img.size(0)
            # Explicit float dtype: with an integer fill value torch.full
            # infers an integer tensor, which MSELoss rejects as a target.
            label = torch.full((batch_size, ),
                               b_label,
                               dtype=torch.float,
                               device=device)

            dis_real = dis(real_img)
            loss_dis_real = criterion(dis_real, label)
            loss_dis_real.backward()

            # train with fake
            z = torch.randn(batch_size, args.nz, 1, 1, device=device)
            fake_img = gen(z)
            label.fill_(a_label)

            # detach(): the discriminator step must not reach into G.
            dis_fake1 = dis(fake_img.detach())
            loss_dis_fake = criterion(dis_fake1, label)
            loss_dis_fake.backward()

            loss_dis = loss_dis_real + loss_dis_fake
            optim_dis.step()

            ############################
            # (2) Update G network: push D(G(z)) towards c
            ###########################
            gen.zero_grad()
            label.fill_(c_label)  # fake labels are real for generator cost

            dis_fake2 = dis(fake_img)
            loss_gen = criterion(dis_fake2, label)
            loss_gen.backward()
            optim_gen.step()

            if (itr + 1) % 100 == 0:
                print(
                    '[{}/{}][{}/{}] LossD:{:.4f} LossG:{:.4f} D(x):{:.4f} D(G(z)):{:.4f}/{:.4f}'
                    .format(epoch + 1, args.nepoch, itr + 1, len(dataloader),
                            loss_dis.item(), loss_gen.item(),
                            dis_real.mean().item(),
                            dis_fake1.mean().item(),
                            dis_fake2.mean().item()))
            # loop end iteration

        if epoch == 0:
            # real_img is the last batch of the first epoch.
            vutils.save_image(real_img,
                              '{}/real_samples.png'.format(args.outf),
                              normalize=True)

        # Sampling only: no autograd graph is needed.
        with torch.no_grad():
            fake_img = gen(fixed_z)
        vutils.save_image(fake_img,
                          '{}/fake_samples_epoch_{:04}.png'.format(
                              args.outf, epoch),
                          normalize=True)

        # do checkpointing
        torch.save(gen.state_dict(),
                   '{}/gen_epoch_{}.pth'.format(args.outf, epoch))
        torch.save(dis.state_dict(),
                   '{}/dis_epoch_{}.pth'.format(args.outf, epoch))
Exemplo n.º 10
0
# Command-line flags for logging cadence, Adam hyper-parameters and the
# number of discriminator updates.
flags.DEFINE_integer('display_interval', 100,
                     'interval of displaying log to console')
flags.DEFINE_float('adam_alpha', 1e-4, 'learning rate')
flags.DEFINE_float('adam_beta1', 0.0, 'beta1 in Adam')
flags.DEFINE_float('adam_beta2', 0.999, 'beta2 in Adam')
flags.DEFINE_integer('n_dis', 1, 'n discrminator train')

# Output directories.
mkdir('results')
mkdir('results/tmp')

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
INCEPTION_FILENAME = 'inception_score.pkl'
# NOTE(review): this binding is overwritten by the dict comprehension two
# lines below; confirm whether the attribute access itself is still needed.
config = FLAGS.__flags

# Plain {flag name: value} dict, expanded as keyword arguments into the models.
config = {k: FLAGS[k]._value for k in FLAGS}
generator = Generator(**config)
# NOTE(review): `Discrminator` is spelled as it is referenced here; presumably
# it matches the name exported by the defining module -- verify.
discriminator = Discrminator(**config)
data_set = Cifar10(batch_size=FLAGS.batch_size)

# Step counter that is advanced explicitly by running `increase_global_step`.
global_step = tf.Variable(0, name="global_step", trainable=False)
increase_global_step = global_step.assign(global_step + 1)
is_training = tf.placeholder(tf.bool, shape=())
# Latent input; its width is taken from the generator's own noise sampler.
z = tf.placeholder(tf.float32,
                   shape=[None, generator.generate_noise().shape[1]])
x_prehat, logprob = generator(z, is_training=is_training)

# Refine the generator output with an HMC sampler driven by the discriminator.
# NOTE(review): (32, 32, 3) is presumably the CIFAR-10 image shape -- confirm.
hmc = HMCSampler((32, 32, 3))
x_hat, hmc_p, delta_lp = hmc.sample(
    lambda x, reuse: discriminator(x, update_collection="NO_OPS"),
    x_prehat,
    reuse_flag=False)
Exemplo n.º 11
0
def main():
    """Parse the command line, then build and run the Chainer GAN trainer.

    Builds the generator/discriminator pair, one Adam optimizer (with a
    weight-decay hook) per network, a serial iterator over a
    ``TrajectoryDataset`` and a ``Trainer`` with snapshot, logging,
    progress-bar and visualisation extensions. If ``--resume`` is given the
    trainer state is restored from that snapshot before training starts.
    """
    # (flags, keyword arguments) in the order the options are registered.
    option_table = [
        (('--batchsize', '-b'),
         dict(type=int, default=100,
              help='Number of images in each mini-batch')),
        (('--epoch', '-e'),
         dict(type=int, default=1000,
              help='Number of sweeps over the dataset to train')),
        (('--gpu', '-g'),
         dict(type=int, default=-1,
              help='GPU ID (negative value indicates CPU)')),
        (('--dataset', '-i'),
         dict(default='',
              help='Directory of image files.  Default is cifar-10.')),
        (('--data_length', '-dl'),
         dict(type=int, default=20, help='length of one data')),
        (('--out', '-o'),
         dict(default='result', help='Directory to output the result')),
        (('--resume', '-r'),
         dict(default='', help='Resume the training from snapshot')),
        (('--n_hidden', '-n'),
         dict(type=int, default=10, help='Number of hidden units (z)')),
        (('--seed', ),
         dict(type=int, default=0,
              help='Random seed of z at visualization stage')),
        (('--snapshot_interval', ),
         dict(type=int, default=1000, help='Interval of snapshot')),
        (('--display_interval', ),
         dict(type=int, default=100,
              help='Interval of displaying log to console')),
        (('--visualize_interval', ),
         dict(type=int, default=100,
              help='Interval of saving visualized image')),
    ]
    cli = argparse.ArgumentParser(description='Chainer example: DCGAN')
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    args = cli.parse_args()

    # Networks to train.
    gen = Generator(n_hidden=args.n_hidden, data_length=args.data_length)
    dis = Discriminator(data_length=args.data_length)

    use_gpu = args.gpu >= 0
    if use_gpu:
        # Select the requested GPU, then copy both models onto it.
        chainer.cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    def build_adam(model, alpha=0.0002, beta1=0.5):
        # Adam bound to `model`, with a weight-decay hook attached.
        adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        adam.setup(model)
        adam.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        return adam

    opt_gen = build_adam(gen)
    opt_dis = build_adam(dis)

    train = TrajectoryDataset(data_length=args.data_length)
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Trainer wiring.
    optimizers = {'gen': opt_gen, 'dis': opt_dis}
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer=optimizers,
                           device=args.gpu)
    stop_trigger = (args.epoch, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # All extension triggers are counted in iterations.
    on_snapshot = (args.snapshot_interval, 'iteration')
    on_display = (args.display_interval, 'iteration')
    on_visualize = (args.visualize_interval, 'iteration')

    trainer_snapshot = extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration}.npz')
    trainer.extend(trainer_snapshot, trigger=on_snapshot)

    gen_snapshot = extensions.snapshot_object(
        gen, 'gen_iter_{.updater.iteration}.npz')
    trainer.extend(gen_snapshot, trigger=on_snapshot)

    dis_snapshot = extensions.snapshot_object(
        dis, 'dis_iter_{.updater.iteration}.npz')
    trainer.extend(dis_snapshot, trigger=on_snapshot)

    trainer.extend(extensions.LogReport(trigger=on_display))

    report_columns = ['epoch', 'iteration', 'gen/loss', 'dis/loss']
    trainer.extend(extensions.PrintReport(report_columns),
                   trigger=on_display)

    trainer.extend(extensions.ProgressBar(update_interval=10))

    visualizer = out_generated_image(gen, dis, train, 10, args.seed,
                                     args.out)
    trainer.extend(visualizer, trigger=on_visualize)

    # Resume from a snapshot when one was given.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Exemplo n.º 12
0
def main():
    """Train a label-conditioned GAN on STL10 with an MSE adversarial loss.

    Parses the command line, builds the STL10 train+test dataset, creates the
    generator and discriminator, and alternates discriminator and generator
    updates. After every epoch a grid of samples for fixed noise/labels is
    written to ``--outf``; both state dicts are saved every 50 epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batch_size',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_g', type=int, default=64)
    parser.add_argument('--nch_d', type=int, default=64)
    parser.add_argument('--n_epoch',
                        type=int,
                        default=200,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        default='./result_cgan',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    # Best effort: any OSError (including "already exists") is ignored here.
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    # Fixed seeds for the Python, NumPy and PyTorch generators.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    trainset = dset.STL10(root='./dataset/stl10_root',
                          download=True,
                          split='train',
                          transform=transforms.Compose([
                              transforms.RandomResizedCrop(64,
                                                           scale=(88 / 96,
                                                                  1.0),
                                                           ratio=(1., 1.)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ColorJitter(brightness=0.05,
                                                     contrast=0.05,
                                                     saturation=0.05,
                                                     hue=0.05),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))  # labels are needed, so the unlabeled split is excluded
    testset = dset.STL10(root='./dataset/stl10_root',
                         download=True,
                         split='test',
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(64,
                                                          scale=(88 / 96, 1.0),
                                                          ratio=(1., 1.)),
                             transforms.RandomHorizontalFlip(),
                             transforms.ColorJitter(brightness=0.05,
                                                    contrast=0.05,
                                                    saturation=0.05,
                                                    hue=0.05),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5)),
                         ]))
    # Train and test splits are combined into one training set.
    dataset = trainset + testset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    # Generator G: produces fake images from a random vector concatenated with a label
    netG = Generator(nz=opt.nz + 10, nch_g=opt.nch_g).to(
        device)  # input size is the random-vector size nz plus the 10 classes
    netG.apply(weights_init)
    print(netG)

    # Discriminator D: decides whether an image concatenated with a label is real or fake
    netD = Discriminator(nch=3 + 10, nch_d=opt.nch_d).to(
        device)  # input channels are the 3 image channels plus the 10 classes
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()

    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)

    fixed_noise = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)

    # NOTE(review): fixed_label has 10 * (batch_size // 10) entries while
    # fixed_noise has batch_size rows; the two only match when batch_size is a
    # multiple of 10 -- confirm how concat_noise_label handles a mismatch.
    fixed_label = [i for i in range(10)] * (opt.batch_size // 10
                                            )  # labels for inspection: 0-9 repeated
    fixed_label = torch.tensor(fixed_label, dtype=torch.long, device=device)

    fixed_noise_label = concat_noise_label(fixed_noise, fixed_label,
                                           device)  # concatenate the inspection noise and labels

    # Training loop
    for epoch in range(opt.n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)  # real images
            real_label = data[1].to(device)  # labels of the real images
            real_image_label = concat_image_label(real_image, real_label,
                                                  device)  # concatenate real images and labels

            sample_size = real_image.size(0)
            noise = torch.randn(sample_size, opt.nz, 1, 1, device=device)
            fake_label = torch.randint(10, (sample_size, ),
                                       dtype=torch.long,
                                       device=device)  # labels used to generate fake images
            fake_noise_label = concat_noise_label(noise, fake_label,
                                                  device)  # concatenate noise and labels

            real_target = torch.full((sample_size, ), 1., device=device)
            fake_target = torch.full((sample_size, ), 0., device=device)

            ############################
            # Update discriminator D
            ###########################
            netD.zero_grad()

            output = netD(real_image_label)  # D's score for the real image/label pairs
            errD_real = criterion(output, real_target)
            D_x = output.mean().item()

            fake_image = netG(fake_noise_label)  # G generates fake images for the labels
            fake_image_label = concat_image_label(fake_image, fake_label,
                                                  device)  # concatenate fake images and labels

            output = netD(
                fake_image_label.detach())  # D's score for the fake image/label pairs
            errD_fake = criterion(output, fake_target)
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            ############################
            # Update generator G
            ###########################
            netG.zero_grad()

            output = netD(
                fake_image_label)  # score the fake image/label pairs again with the updated D
            errG = criterion(output, real_target)
            errG.backward()
            D_G_z2 = output.mean().item()

            optimizerG.step()

            print(
                '[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                .format(epoch + 1, opt.n_epoch, itr + 1, len(dataloader),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save one grid of real images, from the very first batch only.
            if epoch == 0 and itr == 0:
                vutils.save_image(real_image,
                                  '{}/real_samples.png'.format(opt.outf),
                                  normalize=True,
                                  nrow=10)

        ############################
        # Generate images for inspection
        ############################
        # NOTE(review): this forward pass runs with autograd enabled and netG
        # in whatever mode it is currently in; consider torch.no_grad().
        fake_image = netG(
            fixed_noise_label)  # after each epoch, generate fake images for the fixed labels
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(
                              opt.outf, epoch + 1),
                          normalize=True,
                          nrow=10)

        ############################
        # Save the models
        ############################
        if (epoch + 1) % 50 == 0:
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
Exemplo n.º 13
0
#test_data = MyTestDataset(path=opt.test_path, transform = testTransform)

# Training loader: reshuffled every epoch; drop_last=True discards a final
# batch that is smaller than opt.train_bs.
train_loader = DataLoader(dataset=train_data,
                          batch_size=opt.train_bs,
                          shuffle=True,
                          drop_last=True)
# Validation loader: fixed order, keeps the final partial batch.
valid_loader = DataLoader(dataset=valid_data,
                          batch_size=opt.valid_bs,
                          shuffle=False)
#test_loader = DataLoader(dataset=test_data, batch_size=opt.test_bs)

# ------------------------------------ step 2/5 -----------------------------------

# Build the network. NOTE(review): n_res_blk is presumably the number of
# residual blocks inside Generator -- confirm against its definition.
n_res_blk = 12
net = Generator(n_res_blk)  #.to(device)

print(net)
#-----multi-gpu training---------
# Disabled helper: loads a checkpoint and strips the leading `module.` from
# every state-dict key before loading it into the network.
# def load_network(network):
#     save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
#     state_dict = torch.load(save_path)
#     # create new OrderedDict that does not contain `module.`
#     from collections import OrderedDict
#     new_state_dict = OrderedDict()
#     for k, v in state_dict.items():
#         namekey = k[7:] # remove `module.`
#         new_state_dict[namekey] = v
#     # load params
#     network.load_state_dict(new_state_dict)
#     return network
Exemplo n.º 14
0
import torch
from net import Discriminator, Generator
from torch.autograd import Variable
from torchvision.utils import save_image
from utils import to_img
# Sample a batch of images from a trained generator and save them to disk.
z_dimension = 100
batch_size = 64

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

D = Discriminator()
G = Generator(z_dimension)
# map_location lets checkpoints saved on a GPU be loaded on a CPU-only host.
D.load_state_dict(
    torch.load("./model/discriminator.pth", map_location=device))
G.load_state_dict(torch.load("./model/generator.pth", map_location=device))
# Always switch to eval mode, not only when CUDA is present.
D = D.to(device).eval()
G = G.to(device).eval()

with torch.no_grad():
    # Latent batch is created on the CPU and then moved to the chosen device.
    z = torch.randn(batch_size, z_dimension).to(device)
    fake_img = G(z)
    fake_img = to_img(fake_img.cpu())
    save_image(fake_img, "./result/fake_test.png")
Exemplo n.º 15
0
def autosim(args, eng):
    """Set up one run in ``args.output_dir``, train the generator, evaluate it.

    Reads ``Params.json`` from the output directory, applies the command-line
    overrides, creates the output folders, builds the generator with its Adam
    optimizer and StepLR scheduler, optionally restores a checkpoint, trains
    (unless ``numIter`` is 0) and finally generates 500 devices.

    Args:
        args: Parsed options; reads ``output_dir``, ``restore_from``,
            ``gen_ver``, ``wavelength`` and ``angle``.
        eng: Handle passed through unchanged to ``train`` and ``evaluate``.
            NOTE(review): presumably the simulation engine -- confirm.

    Raises:
        AssertionError: If ``Params.json`` is missing from ``args.output_dir``.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    # Set the logger
    utils.set_logger(os.path.join(args.output_dir, 'train.log'))

    # NOTE(review): project helper called with the output folder only; it is
    # not shutil.copyfile (which needs two arguments) -- confirm its purpose.
    copyfile(args.output_dir)

    # Load parameters from json file
    json_path = os.path.join(args.output_dir, 'Params.json')
    assert os.path.isfile(json_path), "No json file found at {}".format(
        json_path)
    params = utils.Params(json_path)

    # Add attributes to params, coercing the numeric settings to int
    params.output_dir = args.output_dir
    params.cuda = torch.cuda.is_available()
    params.restore_from = args.restore_from
    params.numIter = int(params.numIter)
    params.noise_dims = int(params.noise_dims)
    params.gkernlen = int(params.gkernlen)
    params.step_size = int(params.step_size)
    params.gen_ver = int(args.gen_ver)
    params.dime = 1
    # Command-line values override the ones stored in the json file
    if args.wavelength is not None:
        params.wavelength = int(args.wavelength)
    if args.angle is not None:
        params.angle = int(args.angle)

    # Build tools (the recorder used to be constructed twice; once is enough)
    writer = SummaryWriter(log_dir=r'./scan/runs')
    params.recorder = utils.max_recorder()
    params.writer = writer

    try:
        # Make directories
        run_tag = 'deg{}_wl{}_gen_ver{}'.format(params.angle,
                                                params.wavelength,
                                                params.gen_ver)
        for subdir in ('outputs', 'model', 'figures/histogram',
                       'figures/deviceSamples', 'figures/deviceSamples_max',
                       run_tag):
            os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)

        # Define the models
        if params.gen_ver == 0:
            generator = Generator0(params)
        else:
            generator = Generator(params)

        # Move to gpu if possible
        if params.cuda:
            generator.cuda()

        # Define the optimizer
        optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=params.lr,
                                     betas=(params.beta1, params.beta2))

        # Define the scheduler
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=params.step_size, gamma=params.gamma)

        # Load model data. The bare name `restore_from` was undefined here
        # and raised NameError; the path lives on `args`.
        if args.restore_from is not None:
            params.checkpoint = utils.load_checkpoint(args.restore_from,
                                                      generator, optimizer,
                                                      scheduler)
            logging.info('Model data loaded')

        # Set the timer
        timer = utils.timer()

        # Train the model and save
        if params.numIter != 0:
            logging.info('Start training')
            train(generator, optimizer, scheduler, eng, params)

        # Generate images and save
        logging.info('Start generating devices')
        evaluate(generator, eng, numImgs=500, params=params)

        timer.out()
    finally:
        # Flush and release the TensorBoard writer even if the run fails.
        writer.close()
Exemplo n.º 16
0
                                         num_workers=opt.n_cpu)

# Image shape as (channels, height, width), taken from the parsed options.
img_shape = (opt.channels, opt.height, opt.width)


def weights_init(m):
    """Re-initialise the weights of fully connected layers in place.

    Any module whose class name contains ``'Linear'`` has its weight tensor
    overwritten with samples from a normal distribution (mean 0.0,
    std 0.02). Biases and every other kind of module are left untouched.
    """
    # Guard clause: nothing to do for modules that are not linear layers.
    if 'Linear' not in m.__class__.__name__:
        return
    m.weight.data.normal_(0.0, 0.02)


# Build the training objects: the loss function, the two networks and one
# optimizer per network.
batchnorm = True  # forwarded to both network constructors below
adversarial_loss = torch.nn.BCELoss()
generator = Generator(batchnorm=batchnorm,
                      latent=opt.latent,
                      img_shape=img_shape)
discriminator = Discriminator(batchnorm=batchnorm, img_shape=img_shape)
# Both Adam optimizers use the same learning rate and beta pair from `opt`.
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.learning_rate,
                               betas=(opt.beta_1, opt.beta_2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt.learning_rate,
                               betas=(opt.beta_1, opt.beta_2))

# Move both networks to the first CUDA device when one is available,
# otherwise keep them on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator, discriminator = generator.to(device), discriminator.to(device)

# Apply weights_init to every sub-module of the generator.
# NOTE(review): only the generator is initialised in this excerpt; confirm
# whether the discriminator is meant to keep its default initialisation.
generator.apply(weights_init)
Exemplo n.º 17
0
    params.gkernlen = int(params.gkernlen)
    params.step_size = int(params.step_size)

    if args.wavelength is not None:
        params.wavelength = int(args.wavelength)
    if args.angle is not None:
        params.angle = int(args.angle)

    # make directory
    os.makedirs(args.output_dir + '/outputs', exist_ok=True)
    os.makedirs(args.output_dir + '/model', exist_ok=True)
    os.makedirs(args.output_dir + '/figures/histogram', exist_ok=True)
    os.makedirs(args.output_dir + '/figures/deviceSamples', exist_ok=True)

    # Define the models
    generator = Generator(params)

    # Move to gpu if possible
    if params.cuda:
        generator.cuda()

    # Define the optimizer
    optimizer = torch.optim.Adam(generator.parameters(),
                                 lr=params.lr,
                                 betas=(params.beta1, params.beta2))

    # Define the scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=params.step_size,
                                                gamma=params.gamma)
Exemplo n.º 18
0
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                         shuffle=True, num_workers=opt.n_cpu)

def weights_init(m):
    """Re-initialise convolution and batch-norm parameters in place.

    * Class name contains ``'Conv'``: weights are drawn from a normal
      distribution with mean 0.0 and std 0.02.
    * Otherwise, class name contains ``'BatchNorm'``: weights are drawn
      from a normal distribution with mean 1.0 and std 0.02 and the bias
      is set to zero.

    Any other module is left unchanged.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Build the generator and the critic, plus one RMSprop optimizer per network
# (both use the same learning rate from `opt`).
generator = Generator(latent=opt.latent, channels=opt.channels, num_filters=opt.num_filters)
critic = Critic(channels=opt.channels, num_filters=opt.num_filters)
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.learning_rate)
optimizer_C = torch.optim.RMSprop(critic.parameters(), lr=opt.learning_rate)

# Move both networks to the first CUDA device when one is available,
# otherwise keep them on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator, critic = generator.to(device), critic.to(device)
# Apply weights_init to every sub-module of both networks.
generator.apply(weights_init)
critic.apply(weights_init)

if opt.dataset == 'cifar10':
    print(ROOT_DIR + "/cifar")
    if not os.path.isdir(ROOT_DIR + "/cifar"):
        os.mkdir(ROOT_DIR + "/cifar")
elif opt.dataset == 'LSUN':
Exemplo n.º 19
0
def main():
    """Train a DCGAN across several processes with ChainerMN.

    Parses the command line, creates the ChainerMN communicator, builds the
    generator and discriminator with multi-node Adam optimizers, scatters
    the dataset loaded by rank 0 to all workers and runs the Chainer
    trainer. Logging, progress and snapshot extensions are registered on
    rank 0 only.

    Raises:
        SystemExit: with status -1 when ``--gpu`` is combined with the
            ``naive`` communicator.
    """
    parser = argparse.ArgumentParser(description='ChainerMN example: DCGAN')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=50,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true', help='Use GPU')
    parser.add_argument('--dataset',
                        '-i',
                        default='',
                        help='Directory of image files.  Default is cifar-10.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--gen_model',
                        '-r',
                        default='',
                        help='Use pre-trained generator for training')
    parser.add_argument('--dis_model',
                        '-d',
                        default='',
                        help='Use pre-trained discriminator for training')
    parser.add_argument('--n_hidden',
                        '-n',
                        type=int,
                        default=100,
                        help='Number of hidden units (z)')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    # `communicator_name` records the communicator that is really created,
    # which differs from the requested one on the CPU fallback path.
    if args.gpu:
        if args.communicator == 'naive':
            print('Error: \'naive\' communicator does not support GPU.\n')
            # Equivalent to exit(-1) without relying on the `site` builtin.
            raise SystemExit(-1)
        communicator_name = args.communicator
        comm = chainermn.create_communicator(communicator_name)
        device = comm.intra_rank
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        communicator_name = 'naive'
        comm = chainermn.create_communicator(communicator_name)
        device = -1

    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        if args.gpu:
            print('Using GPUs')
        # Report the communicator in use, not merely the requested one.
        print('Using {} communicator'.format(communicator_name))
        print('Num hidden unit: {}'.format(args.n_hidden))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    # Set up a neural network to train
    gen = Generator(n_hidden=args.n_hidden)
    dis = Discriminator()

    if device >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(device).use()
        gen.to_gpu()  # Copy the model to the GPU
        dis.to_gpu()

    # Setup an optimizer
    def make_optimizer(model, comm, alpha=0.0002, beta1=0.5):
        """Return a multi-node Adam optimizer with weight decay for `model`."""
        # Create a multi node optimizer from a standard Chainer optimizer.
        optimizer = chainermn.create_multi_node_optimizer(
            chainer.optimizers.Adam(alpha=alpha, beta1=beta1), comm)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(gen, comm)
    opt_dis = make_optimizer(dis, comm)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        if args.dataset == '':
            # Load the CIFAR10 dataset if args.dataset is not specified
            train, _ = chainer.datasets.get_cifar10(withlabel=False,
                                                    scale=255.)
        else:
            all_files = os.listdir(args.dataset)
            # Select by file extension (case-insensitive). A plain substring
            # test would also pick up names such as 'png_notes.txt'.
            image_files = [f for f in all_files
                           if f.lower().endswith(('.png', '.jpg'))]
            print('{} contains {} image files'.format(args.dataset,
                                                      len(image_files)))
            train = chainer.datasets\
                .ImageDataset(paths=image_files, root=args.dataset)
    else:
        train = None

    train = chainermn.scatter_dataset(train, comm)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Set up a trainer
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer={
                               'gen': opt_gen,
                               'dis': opt_dis
                           },
                           device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        snapshot_interval = (args.snapshot_interval, 'iteration')
        display_interval = (args.display_interval, 'iteration')
        # Save only model parameters.
        # `snapshot` extension will save all the trainer module's attribute,
        # including `train_iter`.
        # However, `train_iter` depends on scattered dataset, which means that
        # `train_iter` may be different in each process.
        # Here, instead of saving whole trainer module, only the network models
        # are saved.
        trainer.extend(extensions.snapshot_object(
            gen, 'gen_iter_{.updater.iteration}.npz'),
                       trigger=snapshot_interval)
        trainer.extend(extensions.snapshot_object(
            dis, 'dis_iter_{.updater.iteration}.npz'),
                       trigger=snapshot_interval)
        trainer.extend(extensions.LogReport(trigger=display_interval))
        trainer.extend(extensions.PrintReport([
            'epoch',
            'iteration',
            'gen/loss',
            'dis/loss',
            'elapsed_time',
        ]),
                       trigger=display_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))
        trainer.extend(out_generated_image(gen, dis, 10, 10, args.seed,
                                           args.out),
                       trigger=snapshot_interval)

    # Start the training using pre-trained model, saved by snapshot_object
    if args.gen_model:
        chainer.serializers.load_npz(args.gen_model, gen)
    if args.dis_model:
        chainer.serializers.load_npz(args.dis_model, dis)

    # Run the training
    trainer.run()
Exemplo n.º 20
0
def train(batch_size, epoch_count, dataset_folder_path, n_hidden, output_path):
    """Train a generator/discriminator pair on GPU 0 and save the results.

    After every epoch one generated sample is written as a PNG and the
    running log is rewritten to ``log.json``. Both models are snapshotted
    every 100 epochs and once more at the end, when the loss graph is
    also rendered.

    Args:
        batch_size: Number of samples per mini-batch; also the number of
            noise vectors drawn per step.
        epoch_count: Number of passes over the dataset.
        dataset_folder_path: Location passed to ``data_io.dataset_load``.
        n_hidden: Length of each noise vector fed to the generator.
        output_path: Directory that receives ``image/``, ``dis/``, ``gen/``,
            ``log.json`` and ``lossGraph.png``. A trailing path separator
            is optional; missing directories are created and existing
            ones are reused.
    """
    dataset = data_io.dataset_load(dataset_folder_path)
    train_iter = SerialIterator(dataset, batch_size, repeat=True, shuffle=True)

    gen = Generator(n_hidden=n_hidden)
    dis = Discriminator()

    gen.to_gpu(0)
    dis.to_gpu(0)

    opt_gen = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
    opt_gen.setup(gen)
    opt_dis = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5)
    opt_dis.setup(dis)

    iteration = 0
    train_iter.reset()

    log_list = []
    # Build paths with os.path.join so the result does not depend on
    # output_path ending with a separator.
    image_path = os.path.join(output_path, "image")
    dis_model_path = os.path.join(output_path, "dis")
    gen_model_path = os.path.join(output_path, "gen")
    log_path = os.path.join(output_path, 'log.json')
    # makedirs also creates output_path itself (as the parent) and
    # tolerates directories left over from a previous run.
    for directory in (image_path, dis_model_path, gen_model_path):
        os.makedirs(directory, exist_ok=True)

    for epoch in range(epoch_count):
        d_loss_list = []
        g_loss_list = []
        while True:
            mini_batch_images = train_iter.next()
            mini_batch_images = np.array(mini_batch_images)
            # Rescale pixel values; presumably the input is in [0, 255]
            # (TODO confirm against data_io.dataset_load).
            mini_batch_images = (mini_batch_images - 128.0) / 128.0
            # The arithmetic above already produced a fresh array, so no
            # additional copy is needed before wrapping it.
            x_real = Variable(mini_batch_images)

            x_real.to_gpu(0)
            y_real = dis(x_real)

            # `xp` is a module-level array backend defined elsewhere in the
            # file (presumably cupy, since the models live on the GPU).
            noise = xp.random.uniform(-1,
                                      1, (batch_size, n_hidden),
                                      dtype=np.float32)
            z = Variable(noise)
            x_fake = gen(z, batch_size)
            y_fake = dis(x_fake)

            d_loss = loss_dis(batch_size, y_real, y_fake)
            g_loss = loss_gen(batch_size, y_fake)

            # Update the discriminator first, then the generator; each
            # model's gradients are cleared right before its own backward.
            dis.cleargrads()
            d_loss.backward()
            opt_dis.update()

            gen.cleargrads()
            g_loss.backward()
            opt_gen.update()

            d_loss.to_cpu()
            g_loss.to_cpu()

            # NOTE: despite its name this counter advances by batch_size,
            # i.e. it counts processed samples.
            iteration += batch_size
            d_loss_list.append(d_loss.array)
            g_loss_list.append(g_loss.array)

            if train_iter.is_new_epoch:
                break

        # Save the first sample of the last generated batch of this epoch.
        x_fake.to_cpu()
        generated_images = x_fake.array
        # Move the second axis to the end before handing the sample to PIL.
        generated_images = generated_images.transpose(0, 2, 3, 1)
        Image.fromarray(
            np.clip(generated_images[0] * 255, 0.0, 255.0).astype(
                np.uint8)).save(os.path.join(image_path, str(epoch) + ".png"))

        # Compute each epoch mean once and reuse it for console and log.
        mean_d_loss = str(np.mean(d_loss_list))
        mean_g_loss = str(np.mean(g_loss_list))

        print("epoch: " + str(epoch) + ", interation: " + str(iteration) +
              ", d_loss: " + mean_d_loss + ", g_loss: " + mean_g_loss)

        log_json = {
            "epoch": str(epoch),
            "iteration": str(iteration),
            "d_loss": mean_d_loss,
            "g_loss": mean_g_loss
        }
        log_list.append(log_json)
        # The whole log is rewritten every epoch so the file on disk is
        # always complete up to the latest epoch.
        with open(log_path, 'w') as log_file:
            json.dump(log_list, log_file, indent=4)

        if (epoch % 100 == 0):
            # Models are moved to the CPU for serialisation and back to
            # GPU 0 afterwards so training can continue.
            dis.to_cpu()
            save_npz(os.path.join(dis_model_path, str(epoch) + '.npz'), dis)
            gen.to_cpu()
            save_npz(os.path.join(gen_model_path, str(epoch) + '.npz'), gen)
            gen.to_gpu(0)
            dis.to_gpu(0)

    logGraph.save_log_graph(log_path,
                            os.path.join(output_path, "lossGraph.png"))
    dis.to_cpu()
    save_npz(os.path.join(dis_model_path, 'last.npz'), dis)
    gen.to_cpu()
    save_npz(os.path.join(gen_model_path, 'last.npz'), gen)
Exemplo n.º 21
0
        idx.sort(key=distance.__getitem__)
        if args.validDistance == 'COS':
            validationScore += COS[i][idx[0]]
        else:
            validationScore += distance[idx[0]]
        dID[i] = [I[i][idx[j]] for j in range(args.knn)]
        d[svoc[i]] = [tvoc[I[i][idx[j]]] for j in range(args.knn)]

    return d, validationScore / n, dID


# -------------------------------------------------------
# MODEL BUILDING

discriminator = Discriminator(args)
generator = Generator(args)

print("* Loss Initialization")
loss_fn = nn.BCELoss()
print(loss_fn)

zeroClass = Variable(torch.Tensor(args.batchSize).fill_(0),
                     requires_grad=False)
oneClass = Variable(torch.Tensor(args.batchSize).fill_(1), requires_grad=False)
smoothedOneClass = Variable(torch.Tensor(args.batchSize).fill_(1 -
                                                               args.smoothing),
                            requires_grad=False)

if args.gpuid >= 0:
    with torch.cuda.device(args.gpuid):
        generator = generator.cuda()