Exemple #1
0
def main():
    """Command-line entry point: train a weight-clipped WGAN on MNIST.

    Parses the command-line options, builds the generator and critic from
    ``Network.mnist_net``, trains them with ``Updater.WGANUpdater`` on the
    MNIST training images whose label is 1, and writes logs, plots,
    generated-image previews and snapshots under ``<out>/<dataset>``.

    Notes:
        * ``--dataset`` only selects the output sub-directory name; the
          data loaded is always MNIST.
        * ``--snapshot_interval`` (in epochs) controls how often the trainer
          state and the two networks are saved.
        * ``--seed`` is only forwarded to the preview-image extension.
    """
    parser = argparse.ArgumentParser(description='WGAN')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=500,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument("--snapshot_interval", "-s", type=int, default=50)
    parser.add_argument("--display_interval", "-d", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset", "-ds", type=str, default="mnist")
    parser.add_argument("--n_dimz", "-z", type=int, default=128)
    args = parser.parse_args()

    out = os.path.join(args.out, args.dataset)

    # Project-local modules are imported here, next to their first use.
    import Network.mnist_net as Network
    import Updater
    import Visualize

    # Networks
    gen = Network.DCGANGenerator(n_hidden=args.n_dimz)
    dis = Network.WGANDiscriminator()
    if args.gpu >= 0:
        # get_device_from_id replaces the deprecated chainer.cuda.get_device.
        chainer.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # Optimizers: RMSprop with gradient clipping for both networks; the
    # critic additionally has its weights clipped after every update.
    opt_gen = chainer.optimizers.RMSprop(5e-5)
    opt_gen.setup(gen)
    opt_gen.add_hook(chainer.optimizer.GradientClipping(1))

    opt_dis = chainer.optimizers.RMSprop(5e-5)
    opt_dis.setup(dis)
    opt_dis.add_hook(chainer.optimizer.GradientClipping(1))
    opt_dis.add_hook(WeightClipping(0.01))

    # Dataset: keep only the images whose label is 1 (labels are dropped).
    train, _ = mnist.get_mnist(withlabel=True, ndim=3, scale=1.)
    train = [image for image, label in train if label == 1]

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Trainer
    updater = Updater.WGANUpdater(models=(gen, dis),
                                  iterator=train_iter,
                                  optimizer={
                                      'gen': opt_gen,
                                      'dis': opt_dis
                                  },
                                  n_dis=5,
                                  device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=out)

    # Honour --snapshot_interval instead of snapshotting only at the last
    # epoch, so that --resume has intermediate states to restart from.
    snapshot_interval = (args.snapshot_interval, 'epoch')
    display_interval = (args.display_interval, 'epoch')

    # Extensions
    trainer.extend(extensions.dump_graph('wasserstein distance'))
    trainer.extend(
        extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_epoch_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_epoch_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PlotReport(['wasserstein distance'],
                              'epoch',
                              file_name='distance.png'))
    trainer.extend(
        extensions.PlotReport(['gen/loss'], 'epoch', file_name='loss.png'))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'wasserstein distance', 'gen/loss', 'elapsed_time']),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar())
    # NOTE(review): the preview extension receives args.out and args.dataset
    # separately; presumably it joins them itself -- confirm in Visualize.
    trainer.extend(Visualize.out_generated_image(gen, dis, 10, 10, args.seed,
                                                 args.out, args.dataset),
                   trigger=display_interval)
    if args.resume:
        # Restore the full trainer state from a previously saved snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    # Run
    trainer.run()
Exemple #2
0
def main():
    """Command-line entry point: train a WGAN-gp model on MNIST.

    Parses the command-line options, builds the generator and critic from
    ``Network.mnist_net``, trains them with ``Updater.WGANUpdater`` (called
    with ``lam=10``) on the MNIST training images whose label is 1, and
    writes logs, generated-image previews and snapshots under
    ``<out>/<dataset>``.

    Notes:
        * ``--dataset`` only selects the output sub-directory name; the
          data loaded is always MNIST.
        * ``--snapshot_interval`` (in epochs) controls how often the trainer
          state and the two networks are saved.
        * ``--n_dimz`` is forwarded to the generator as ``n_hidden``.
        * ``--seed`` is only forwarded to the preview-image extension.
    """
    parser = argparse.ArgumentParser(description="WGAN-gp")
    parser.add_argument("--batchsize", "-b", type=int, default=64)
    parser.add_argument("--epoch", type=int, default=500)
    parser.add_argument("--gpu", "-g", type=int, default=0)
    parser.add_argument("--snapshot_interval", "-s", type=int, default=50)
    parser.add_argument("--display_interval", "-d", type=int, default=1)
    parser.add_argument("--n_dimz", "-z", type=int, default=128)
    parser.add_argument("--dataset", "-ds", type=str, default="mnist")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--out", "-o", type=str, default="result")
    parser.add_argument("--resume", '-r', default='')
    args = parser.parse_args()

    # Project-local modules
    import Updater
    import Visualize
    import Network.mnist_net as Network

    # Print settings
    print("GPU:{}".format(args.gpu))
    print("max_epoch:{}".format(args.epoch))
    print("Minibatch_size:{}".format(args.batchsize))
    print("Dataset:{}".format(args.dataset))
    print('')
    out = os.path.join(args.out, args.dataset)

    # Set up the networks. The latent size comes from --n_dimz so that the
    # flag is no longer silently ignored.
    gen = Network.DCGANGenerator(n_hidden=args.n_dimz)
    dis = Network.WGANDiscriminator()

    if args.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    def make_optimizer(model, alpha=0.0002, beta1=0.0, beta2=0.9):
        """Return an Adam optimizer already set up for ``model``."""
        optimizer = chainer.optimizers.Adam(alpha=alpha,
                                            beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        return optimizer

    opt_gen = make_optimizer(gen)
    opt_dis = make_optimizer(dis)

    # Dataset: keep only the images whose label is 1 (labels are dropped).
    train, _ = mnist.get_mnist(withlabel=True, ndim=3, scale=1.)
    train = [image for image, label in train if label == 1]

    # Set up the iterator
    train_iter = iterators.SerialIterator(train, args.batchsize)

    # Set up the updater
    updater = Updater.WGANUpdater(models=(gen, dis),
                                  iterator=train_iter,
                                  optimizer={
                                      'gen': opt_gen,
                                      'dis': opt_dis
                                  },
                                  n_dis=5,
                                  lam=10,
                                  device=args.gpu)

    # Set up the trainer
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=out)

    # Honour --snapshot_interval instead of snapshotting only at the last
    # epoch, so that --resume has intermediate states to restart from.
    snapshot_interval = (args.snapshot_interval, 'epoch')
    display_interval = (args.display_interval, 'epoch')
    trainer.extend(
        extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_epoch_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_epoch_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'gen/loss', 'dis/loss', 'loss_grad',
        'wasserstein_distance', 'elapsed_time'
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar())
    # NOTE(review): the preview extension receives args.out and args.dataset
    # separately; presumably it joins them itself -- confirm in Visualize.
    trainer.extend(Visualize.out_generated_image(gen, dis, 10, 10, args.seed,
                                                 args.out, args.dataset),
                   trigger=display_interval)

    if args.resume:
        # Restore the full trainer state from a previously saved snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()