def main(args):
    """Train the generator/critic pair on MNIST images with a chainer Trainer.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``g`` (GPU id; a negative value keeps everything on the CPU),
            ``batchsize``, ``alpha``/``beta1``/``beta2`` (passed positionally
            to Adam), ``n_cri`` and ``gp_lam`` (forwarded to ``WGANUpdater``),
            ``epoch`` (training length in epochs) and ``result_dir``
            (Trainer output directory).
    """
    # Custom config entry, not a built-in chainer flag.
    # NOTE(review): presumably read by other code in this project -- confirm.
    chainer.config.user_gpu = args.g
    use_gpu = args.g >= 0
    if use_gpu:
        chainer.backends.cuda.get_device_from_id(args.g).use()
        print("GPU mode")

    # Fetch MNIST and keep only the training split.
    train_split, _ = chainer.datasets.get_mnist()
    # Images only; the labels are not needed.
    images = chainer.dataset.concat_examples(train_split)[0]
    # Reshape to image layout (N, C, H, W).
    images = images.reshape((-1, 1, 28, 28))
    # Endless, shuffled iterator over the images.
    image_iter = iterators.SerialIterator(
        images,
        args.batchsize,
        shuffle=True,
        repeat=True,
    )

    generator = Generator(Z_DIM)
    critic = Critic()
    if use_gpu:
        generator.to_gpu()
        critic.to_gpu()

    # Both optimizers share the same Adam hyper-parameters.
    adam_params = (args.alpha, args.beta1, args.beta2)
    generator_opt = optimizers.Adam(*adam_params)
    generator_opt.setup(generator)
    critic_opt = optimizers.Adam(*adam_params)
    critic_opt.setup(critic)

    updater = WGANUpdater(
        image_iter,
        generator_opt,
        critic_opt,
        args.n_cri,
        args.gp_lam,
    )
    trainer = Trainer(updater, (args.epoch, "epoch"), out=args.result_dir)

    # Reporting extensions.
    trainer.extend(extensions.LogReport())
    printed_keys = ["epoch", "generator/loss", "critic/loss"]
    trainer.extend(extensions.PrintReport(printed_keys))
    trainer.extend(extensions.ProgressBar())
    # NOTE(review): the plot tracks "main/wdist" while the printed report
    # tracks "critic/loss" -- confirm which keys the updater reports.
    plotted_keys = ("generator/loss", "main/wdist")
    trainer.extend(
        extensions.PlotReport(plotted_keys, "epoch", file_name="loss_plot.eps")
    )

    # Save the generator every 10 epochs.
    generator_snapshot = extensions.snapshot_object(
        generator, "model_epoch_{.updater.epoch}.model"
    )
    trainer.extend(generator_snapshot, trigger=(10, "epoch"))

    # Project-specific extension; presumably writes generated sample images.
    image_dir = args.result_dir + "/out_images"
    trainer.extend(ext_save_img(generator, image_dir))

    trainer.run()
def main(args):
    """Train the generator/critic pair on two MNIST digit subsets.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``g`` (GPU id; a negative value keeps everything on the CPU),
            ``neg_numbers`` and ``pos_numbers`` (passed to ``get_mnist_num``
            to build the two datasets), ``batchsize``,
            ``alpha``/``beta1``/``beta2`` (passed positionally to Adam),
            ``n_cri1``, ``n_cri2``, ``gp_lam`` and ``l1_lam`` (forwarded to
            ``WGANUpdater``), ``epoch`` (training length in epochs) and
            ``result_dir`` (Trainer output directory).
    """
    # Custom config entry, not a built-in chainer flag.
    # NOTE(review): presumably read by other code in this project -- confirm.
    chainer.config.user_gpu = args.g
    use_gpu = args.g >= 0
    if use_gpu:
        chainer.backends.cuda.get_device_from_id(args.g).use()
        print("GPU mode")

    neg_data = get_mnist_num(args.neg_numbers)
    pos_data = get_mnist_num(args.pos_numbers)

    # Create the iterators: endless and shuffled, same batch size for both.
    neg_iter = iterators.SerialIterator(
        neg_data,
        batch_size=args.batchsize,
        shuffle=True,
        repeat=True,
    )
    pos_iter = iterators.SerialIterator(
        pos_data,
        batch_size=args.batchsize,
        shuffle=True,
        repeat=True,
    )

    generator = Generator()
    critic = Critic()
    if use_gpu:
        generator.to_gpu()
        critic.to_gpu()

    # Both optimizers share the same Adam hyper-parameters.
    generator_opt = optimizers.Adam(args.alpha, args.beta1, args.beta2)
    generator_opt.setup(generator)
    critic_opt = optimizers.Adam(args.alpha, args.beta1, args.beta2)
    critic_opt.setup(critic)

    updater = WGANUpdater(
        neg_iter,
        pos_iter,
        generator_opt,
        critic_opt,
        args.n_cri1,
        args.n_cri2,
        args.gp_lam,
        args.l1_lam,
    )
    trainer = Trainer(updater, (args.epoch, "epoch"), out=args.result_dir)

    # Reporting extensions.
    trainer.extend(extensions.LogReport())
    printed_keys = ["epoch", "generator/loss", "critic/loss"]
    trainer.extend(extensions.PrintReport(printed_keys))
    trainer.extend(extensions.ProgressBar())
    plotted_keys = ("generator/loss", "critic/loss")
    trainer.extend(
        extensions.PlotReport(plotted_keys, "epoch", file_name="loss_plot.eps")
    )

    # Project-specific extension; presumably writes generated sample images
    # using the pos_numbers dataset as input.
    image_dir = args.result_dir + "/out_images"
    trainer.extend(ext_save_img(generator, pos_data, image_dir))

    trainer.run()