예제 #1
0
def main(args):
    """Train a WGAN generator/critic pair on the MNIST training images.

    Builds ``Generator(Z_DIM)`` and ``Critic()``, gives each its own Adam
    optimizer, wires them into a Chainer ``Trainer`` driven by
    ``WGANUpdater`` and runs training for ``args.epoch`` epochs.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``g`` (GPU id; a negative value keeps everything on the CPU),
            ``batchsize``, ``alpha``/``beta1``/``beta2`` (passed to
            ``optimizers.Adam``), ``n_cri`` and ``gp_lam`` (forwarded
            unchanged to ``WGANUpdater``), ``epoch`` and ``result_dir``
            (trainer output directory; images go to ``<result_dir>/out_images``).
    """
    gpu_id = args.g
    on_gpu = gpu_id >= 0

    # Custom config attribute; presumably consumed elsewhere in the project — verify.
    chainer.config.user_gpu = gpu_id
    if on_gpu:
        chainer.backends.cuda.get_device_from_id(gpu_id).use()
        print("GPU mode")

    # Keep only the images of the MNIST training split; labels are not needed.
    train_split = chainer.datasets.get_mnist()[0]
    images = chainer.dataset.concat_examples(train_split)[0]
    images = images.reshape((-1, 1, 28, 28))  # -> (N, C, H, W) image layout
    image_iter = iterators.SerialIterator(
        images, args.batchsize, shuffle=True, repeat=True
    )

    generator, critic = Generator(Z_DIM), Critic()
    if on_gpu:
        for model in (generator, critic):
            model.to_gpu()

    def adam_for(model):
        # Both networks use identical Adam hyper-parameters.
        optimizer = optimizers.Adam(args.alpha, args.beta1, args.beta2)
        optimizer.setup(model)
        return optimizer

    opt_g = adam_for(generator)
    opt_c = adam_for(critic)

    updater = WGANUpdater(image_iter, opt_g, opt_c, args.n_cri, args.gp_lam)
    trainer = Trainer(updater, (args.epoch, "epoch"), out=args.result_dir)

    # Logging / console reporting.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ["epoch", "generator/loss", "critic/loss"]
    ))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PlotReport(
        ("generator/loss", "main/wdist"), "epoch", file_name="loss_plot.eps"
    ))

    # Snapshot the generator every 10 epochs and dump sample images.
    generator_snapshot = extensions.snapshot_object(
        generator, "model_epoch_{.updater.epoch}.model"
    )
    trainer.extend(generator_snapshot, trigger=(10, "epoch"))
    trainer.extend(ext_save_img(generator, args.result_dir + "/out_images"))

    trainer.run()
예제 #2
0
def main(args):
    """Train a WGAN that maps one MNIST digit subset towards another.

    Two MNIST subsets are loaded with ``get_mnist_num``: one from
    ``args.neg_numbers`` and one from ``args.pos_numbers``. The variable
    names in the original suggest digits 3 and 8 by default, but this is
    unverified. Each subset gets its own shuffling, repeating iterator.
    ``Generator()`` and ``Critic()`` are then trained with ``WGANUpdater``
    inside a Chainer ``Trainer``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``g`` (GPU id; a negative value keeps everything on the CPU),
            ``neg_numbers``/``pos_numbers`` (passed to ``get_mnist_num``),
            ``batchsize``, ``alpha``/``beta1``/``beta2`` (passed to
            ``optimizers.Adam``), ``n_cri1``, ``n_cri2``, ``gp_lam`` and
            ``l1_lam`` (forwarded unchanged to ``WGANUpdater``), ``epoch``
            and ``result_dir`` (trainer output directory; images go to
            ``<result_dir>/out_images``).
    """
    gpu_id = args.g
    on_gpu = gpu_id >= 0

    # Custom config attribute; presumably consumed elsewhere in the project — verify.
    chainer.config.user_gpu = gpu_id
    if on_gpu:
        chainer.backends.cuda.get_device_from_id(gpu_id).use()
        print("GPU mode")

    source_data = get_mnist_num(args.neg_numbers)
    target_data = get_mnist_num(args.pos_numbers)

    # One shuffling, endlessly repeating iterator per digit subset.
    source_iter = iterators.SerialIterator(
        source_data, batch_size=args.batchsize, shuffle=True, repeat=True
    )
    target_iter = iterators.SerialIterator(
        target_data, batch_size=args.batchsize, shuffle=True, repeat=True
    )

    generator, critic = Generator(), Critic()
    if on_gpu:
        for model in (generator, critic):
            model.to_gpu()

    def adam_for(model):
        # Generator and critic share identical Adam hyper-parameters.
        optimizer = optimizers.Adam(args.alpha, args.beta1, args.beta2)
        optimizer.setup(model)
        return optimizer

    opt_g = adam_for(generator)
    opt_c = adam_for(critic)

    updater = WGANUpdater(
        source_iter, target_iter, opt_g, opt_c,
        args.n_cri1, args.n_cri2, args.gp_lam, args.l1_lam,
    )
    trainer = Trainer(updater, (args.epoch, "epoch"), out=args.result_dir)

    # Logging / console reporting.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ["epoch", "generator/loss", "critic/loss"]
    ))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PlotReport(
        ("generator/loss", "critic/loss"), "epoch", file_name="loss_plot.eps"
    ))

    # Periodically dump generator outputs for the target-digit dataset.
    image_dir = args.result_dir + "/out_images"
    trainer.extend(ext_save_img(generator, target_data, image_dir))

    trainer.run()