def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--data_path', type=str, default="data/datasets/cifar10/train",
                        help='dataset directory path')
    parser.add_argument('--class_name', '-class', type=str, default='all_class',
                        help='class name (default: all_class(str))')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of parallel data loading processes')
    parser.add_argument('--batchsize', '-b', type=int, default=16)
    parser.add_argument('--max_iter', '-m', type=int, default=40000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=2500,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=5000,
                        help='Interval of evaluation')
    parser.add_argument('--out_image_interval', type=int, default=1250,
                        help='Interval of writing sample images')
    parser.add_argument('--stage_interval', type=int, default=40000,
                        help='Interval of stage progress')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=1,
                        help='number of discriminator update per generator update')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty coefficient')
    parser.add_argument('--gamma', type=float, default=750,
                        help='gradient penalty target')
    parser.add_argument('--pooling_comp', type=float, default=1.0,
                        help='compensation')
    parser.add_argument('--pretrained_generator', type=str, default="")
    parser.add_argument('--pretrained_discriminator', type=str, default="")
    parser.add_argument('--initial_stage', type=float, default=0.0)
    parser.add_argument('--generator_smoothing', type=float, default=0.999)
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["stage", "loss_dis", "loss_gp", "loss_gen", "g",
                   "inception_mean", "inception_std", "FID"]
    max_iter = args.max_iter

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()

    generator = Generator()
    generator_smooth = Generator()
    discriminator = Discriminator(pooling_comp=args.pooling_comp)

    # select GPU
    if args.gpu >= 0:
        generator.to_gpu()
        generator_smooth.to_gpu()
        discriminator.to_gpu()
        print("use gpu {}".format(args.gpu))

    if args.pretrained_generator != "":
        chainer.serializers.load_npz(args.pretrained_generator, generator)
    if args.pretrained_discriminator != "":
        chainer.serializers.load_npz(args.pretrained_discriminator, discriminator)
    copy_param(generator_smooth, generator)

    # Setup an optimizer
    def make_optimizer(model, alpha=0.001, beta1=0.0, beta2=0.99):
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
        optimizer.setup(model)
        # optimizer.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(generator)
    opt_dis = make_optimizer(discriminator)

    if args.class_name == 'all_class':
        data_path = args.data_path
        one_class_flag = False
    else:
        data_path = os.path.join(args.data_path, args.class_name)
        one_class_flag = True
    train_dataset = ImageDataset(data_path, one_class_flag=one_class_flag)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batchsize, n_processes=args.num_workers)

    # Set up a trainer
    updater = Updater(models=(generator, discriminator, generator_smooth),
                      iterator={'main': train_iter},
                      optimizer={'opt_gen': opt_gen, 'opt_dis': opt_dis},
                      device=args.gpu,
                      n_dis=args.n_dis,
                      lam=args.lam,
                      gamma=args.gamma,
                      smoothing=args.generator_smoothing,
                      initial_stage=args.initial_stage,
                      stage_interval=args.stage_interval)
    trainer = training.Trainer(updater, (max_iter, 'iteration'), out=args.out)

    trainer.extend(extensions.snapshot_object(
        generator, 'generator_{.updater.iteration}.npz'),
        trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator_smooth, 'generator_smooth_{.updater.iteration}.npz'),
        trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_{.updater.iteration}.npz'),
        trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator_smooth, args.out),
                   trigger=(args.out_image_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator_smooth, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator_smooth),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator_smooth),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
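
# copy_param(...) above is not defined in this section. A minimal sketch,
# assuming it copies each parameter of the source link into the identically
# named parameter of the target link (the usual initialization for the
# exponentially smoothed generator); the real helper may differ.
def copy_param(target_link, source_link):
    target_params = dict(target_link.namedparams())
    for name, param in source_link.namedparams():
        target_params[name].data = param.data.copy()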
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm', '-a', type=str, default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture', type=str, default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=5,
                        help='number of discriminator update per generator update')
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2', type=float, default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument('--output_dim', type=int, default=256,
                        help='output dimension of the discriminator (for cramer GAN)')
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--image-npz', type=str, default="")
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--resume', type=str, default="")
    parser.add_argument('--ch', type=int, default=512)
    parser.add_argument('--snapshot-iter', type=int, default=0)
    parser.add_argument('--range', type=float, default=1.0)
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen"]

    # Set up dataset
    # train_dataset = Cifar10Dataset()
    if args.image_npz != '':
        from dataset.cifar10like import NPZColorDataset
        train_dataset = NPZColorDataset(npz=args.image_npz)
    elif args.data_dir != '':
        from dataset.cifar10like import CIFAR10Like
        train_dataset = CIFAR10Like(args.data_dir)
    else:
        raise ValueError("specify either --image-npz or --data-dir")
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "lim_dcgan":
            generator = common.net.LimDCGANGenerator(n_hidden=args.n_hidden,
                                                     range=args.range)
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        from stdgan.updater import Updater
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "sndcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.SNDCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma
    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "my_cramer":
        from cramer.my_updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "lim_cramer":
        # from cramer.my_updater import Updater
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.LimDCGANGenerator(range=args.range)
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.snapshot_iter == 0:
        snap_iter = args.max_iter // 100
    else:
        snap_iter = args.snapshot_iter
    trainer.extend(extensions.snapshot(), trigger=(snap_iter, 'iteration'))

    # resume
    if args.resume != "":
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
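
# The scripts in this section call make_optimizer(...) without defining it.
# A minimal sketch, assuming it mirrors the inline definition in the first
# script: plain Adam with the given hyperparameters. The real helper may add
# hooks (e.g. weight decay).
def make_optimizer(model, alpha=0.0002, beta1=0.0, beta2=0.9):
    optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
    optimizer.setup(model)
    return optimizer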
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm', '-a', type=str, default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture', type=str, default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=2)
    parser.add_argument('--bound', type=float, default=1.0)
    parser.add_argument('--h_dim', type=int, default=128)
    parser.add_argument('--max_iter', type=int, default=150000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=1,
                        help='number of discriminator update per generator update')
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty')
    parser.add_argument('--shift_interval', type=float, default=10000,
                        help='Interval of shift annealing')
    parser.add_argument('--adam_alpha', type=float, default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2', type=float, default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument('--output_dim', type=int, default=256,
                        help='output dimension of the discriminator (for cramer GAN)')
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["inception_mean", "inception_std", "FID"]
    # report_keys = ["loss_dis", "loss_gen", "ais", "inception_mean",
    #                "inception_std", "FID"]

    # Set up dataset
    train_dataset = Cifar10Dataset()
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "gan":
        from gan.net import Discriminator, Denoiser
        from gan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
        updater_args["n_dis"] = args.n_dis
        updater_args["shift_interval"] = args.shift_interval
        updater_args['n_iter'] = args.max_iter
    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
            denoiser = common.net.Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        report_keys.append("loss_gp")
        models = [generator, discriminator, denoiser]
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
        updater_args["shift_interval"] = args.shift_interval
        updater_args['n_iter'] = args.max_iter
    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
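
# The wgan_gp/dragan/cramer updaters imported above report a "loss_gp"
# scaled by --lam, but their bodies are not part of this section. A minimal
# sketch of the WGAN-GP penalty they presumably compute, under the standard
# formulation; the function name, variable names, and array-typed inputs
# are assumptions.
import chainer
import chainer.functions as F


def gradient_penalty(discriminator, x_real, x_fake, lam=10.0):
    xp = discriminator.xp
    # Random interpolation between real and fake samples.
    eps = xp.random.uniform(0, 1, size=(x_real.shape[0], 1, 1, 1)).astype("f")
    x_hat = chainer.Variable(eps * x_real + (1.0 - eps) * x_fake)
    d_hat = discriminator(x_hat)
    # Gradient of the critic output w.r.t. the interpolated input.
    grad, = chainer.grad([d_hat], [x_hat], enable_double_backprop=True)
    grad_norm = F.sqrt(F.batch_l2_norm_squared(grad))
    # Penalize deviation of the gradient norm from 1.
    return lam * F.mean_squared_error(grad_norm, xp.ones_like(grad_norm.data))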
def main():
    parser = argparse.ArgumentParser(description='Train StarGAN')
    parser.add_argument('--source_path', default="source/celebA/",
                        help="data resource Directory")
    parser.add_argument('--att_list_path', default="att_list.txt",
                        help="attribute list")
    parser.add_argument('--batch_size', '-b', type=int, default=16)
    parser.add_argument('--max_iter', '-m', type=int, default=200000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--eval_folder', '-e', default='test',
                        help='Directory to output the evaluation result')
    parser.add_argument('--eval_interval', type=int, default=1000,
                        help='Interval of evaluating generator')
    parser.add_argument("--learning_rate_g", type=float, default=0.0001,
                        help="Learning rate for generator")
    parser.add_argument("--learning_rate_d", type=float, default=0.0001,
                        help="Learning rate for discriminator")
    parser.add_argument("--load_gen_model", default='',
                        help='load generator model')
    parser.add_argument("--load_dis_model", default='',
                        help='load discriminator model')
    parser.add_argument('--gen_class', default='StarGAN_Generator',
                        help='Default generator class')
    parser.add_argument('--dis_class', default='StarGAN_Discriminator',
                        help='Default discriminator class')
    parser.add_argument("--n_dis", type=int, default=6,
                        help='The number of loop of WGAN Discriminator')
    parser.add_argument("--lambda_gp", type=float, default=10.0,
                        help='lambda for gradient penalty of WGAN')
    parser.add_argument("--lambda_adv", type=float, default=1.0,
                        help='lambda for adversarial loss')
    parser.add_argument("--lambda_cls", type=float, default=1.0,
                        help='lambda for classification loss')
    parser.add_argument("--lambda_rec", type=float, default=10.0,
                        help='lambda for reconstruction loss')
    parser.add_argument("--flip", type=int, default=1,
                        help='flip images for data augmentation')
    parser.add_argument("--resize_to", type=int, default=128,
                        help='resize the image to')
    parser.add_argument("--crop_to", type=int, default=178,
                        help='crop the resized image to')
    parser.add_argument("--load_dataset", default='celebA_train',
                        help='load dataset')
    parser.add_argument("--discriminator_layer_n", type=int, default=6,
                        help='number of discriminator layers')
    parser.add_argument("--learning_rate_anneal", type=float, default=10e-8,
                        help='anneal the learning rate')
    parser.add_argument("--learning_rate_anneal_start", type=int, default=100000,
                        help='time to anneal the learning')
    args = parser.parse_args()
    print(args)
    record_setting(args.out)
    max_iter = args.max_iter

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()

    with open(args.att_list_path, "r") as f:
        att_list = []
        att_name = []
        for line in f:
            line = line.strip().split(" ")
            if len(line) == 3:
                att_list.append(int(line[0]))  # attID
                att_name.append(line[1])       # attname
    print("attribute list:", ",".join(att_name))

    # load dataset
    train_dataset = getattr(celebA, args.load_dataset)(
        args.source_path, att_name, flip=args.flip,
        resize_to=args.resize_to, crop_to=args.crop_to)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batch_size, n_processes=4)
    # test_dataset = getattr(celebA, args.load_dataset)(root_celebA, flip=args.flip, resize_to=args.resize_to, crop_to=args.crop_to)
    test_batchsize = 8
    test_iter = chainer.iterators.SerialIterator(train_dataset, test_batchsize)

    # set generator and discriminator
    nc_size = len(att_list)  # num of attribute
    gen = getattr(net, args.gen_class)(args.resize_to, nc_size)
    dis = getattr(net, args.dis_class)(n_down_layers=args.discriminator_layer_n)

    if args.load_gen_model != '':
        serializers.load_npz(args.load_gen_model, gen)
        print("Generator model loaded")
    if args.load_dis_model != '':
        serializers.load_npz(args.load_dis_model, dis)
        print("Discriminator model loaded")

    if not os.path.exists(args.eval_folder):
        os.makedirs(args.eval_folder)

    # select GPU
    if args.gpu >= 0:
        gen.to_gpu()
        dis.to_gpu()
        print("use gpu {}".format(args.gpu))

    # Setup an optimizer
    def make_optimizer(model, alpha=0.0001, beta1=0.5, beta2=0.999):
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
        optimizer.setup(model)
        return optimizer

    opt_gen = make_optimizer(gen, alpha=args.learning_rate_g)
    opt_dis = make_optimizer(dis, alpha=args.learning_rate_d)

    # Set up a trainer
    updater = Updater(models=(gen, dis),
                      iterator={'main': train_iter, 'test': test_iter},
                      optimizer={'opt_gen': opt_gen, 'opt_dis': opt_dis},
                      device=args.gpu,
                      params={
                          'n_dis': args.n_dis,
                          'lambda_adv': args.lambda_adv,
                          'lambda_cls': args.lambda_cls,
                          'lambda_rec': args.lambda_rec,
                          'lambda_gp': args.lambda_gp,
                          'image_size': args.resize_to,
                          'eval_folder': args.eval_folder,
                          'nc_size': nc_size,
                          'learning_rate_anneal': args.learning_rate_anneal,
                          'learning_rate_anneal_start': args.learning_rate_anneal_start,
                          'dataset': train_dataset
                      })

    model_save_interval = (4000, 'iteration')
    trainer = training.Trainer(updater, (max_iter, 'iteration'), out=args.out)
    trainer.extend(extensions.snapshot_object(gen, 'gen{.updater.iteration}.npz'),
                   trigger=model_save_interval)
    trainer.extend(extensions.snapshot_object(dis, 'dis{.updater.iteration}.npz'),
                   trigger=model_save_interval)

    log_keys = ['epoch', 'iteration', 'lr_g', 'lr_d',
                'loss_dis_adv', 'loss_gen_adv', 'loss_dis_cls',
                'loss_gen_cls', 'loss_gen_rec', 'loss_gp']
    trainer.extend(extensions.LogReport(keys=log_keys, trigger=(20, 'iteration')))
    trainer.extend(extensions.PrintReport(log_keys), trigger=(20, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=50))
    trainer.extend(evaluation(gen, args.eval_folder, image_size=args.resize_to),
                   trigger=(args.eval_interval, 'iteration'))
    # trainer.extend(CommandsExtension())

    # Run the training
    trainer.run()
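
# The attribute-list parser above keeps only lines that split on single
# spaces into exactly three fields: an integer attribute ID, the attribute
# name, and one extra field. A hypothetical att_list.txt in that format
# (IDs, names, and the third column are illustrative, not real celebA
# indices):
#
#   6 Bangs 1
#   10 Blond_Hair 1
#   16 Eyeglasses 1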
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm', '-a', type=str, default='wgan_gp_res',
                        help='GAN algorithm')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=60000)
    parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU ID')
    parser.add_argument('--out', type=str, default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of display')
    parser.add_argument('--n_dis', type=int, default=5,
                        help='number of discriminator update per generator update')
    parser.add_argument('--lam', type=float, default=10, help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002)
    parser.add_argument('--adam_beta1', type=float, default=0.0)
    parser.add_argument('--adam_beta2', type=float, default=0.9)
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ['loss_dis', 'loss_gen', 'inception_mean', 'inception_std']

    # set up dataset
    train_dataset = Cifar10Dataset()
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # set up networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == 'wgan_gp_res':
        from updaters.wgangp_updater import Updater
        import gen_models.resnet_generator
        import dis_models.resnet_discriminator
        generator = gen_models.resnet_generator.ResnetGenerator()
        discriminator = dis_models.resnet_discriminator.ResnetDiscriminator()
        models = [generator, discriminator]
        report_keys.append('loss_gp')
        updater_args['n_dis'] = args.n_dis
        updater_args['lam'] = args.lam
    elif args.algorithm == 'sngan_res':
        from updaters.stdgan_updater import Updater
        import gen_models.resnet_generator
        import dis_models.sn_resnet_discriminator
        generator = gen_models.resnet_generator.ResnetGenerator()
        discriminator = dis_models.sn_resnet_discriminator.SNResnetDiscriminator()
        models = [generator, discriminator]
        updater_args['n_dis'] = args.n_dis
    elif args.algorithm == 'snwgan_res':
        from updaters.wgan_like_updater import Updater
        import gen_models.resnet_generator
        import dis_models.sn_resnet_discriminator
        generator = gen_models.resnet_generator.ResnetGenerator()
        discriminator = dis_models.sn_resnet_discriminator.SNResnetDiscriminator()
        models = [generator, discriminator]
        updater_args['n_dis'] = args.n_dis
    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        for m in models:
            m.to_gpu()

    # set up optimizers
    opts['opt_gen'] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts['opt_dis'] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args['optimizer'] = opts
    updater_args['models'] = models

    # set up updater
    updater = Updater(**updater_args)

    # set up trainer
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(extensions.PlotReport(
        ['loss_dis', 'loss_gen'], 'iteration',
        trigger=(args.display_interval, 'iteration'), file_name='loss.png'),
        trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate8(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    # trainer.extend(calc_inception(generator, batchsize=100, n_ims=1000),
    #                trigger=(args.evaluation_interval, 'iteration'),
    #                priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # train
    trainer.run()
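
# A minimal sketch of sampling from the snapshots the script above writes.
# The file name follows its snapshot_object template
# ('ResnetGenerator_<iteration>.npz'); the iteration number is illustrative,
# and make_hidden() is assumed to exist on the generator as in common
# chainer-gan-lib style networks.
import chainer
import gen_models.resnet_generator


def sample_from_snapshot():
    generator = gen_models.resnet_generator.ResnetGenerator()
    chainer.serializers.load_npz('result/ResnetGenerator_60000.npz', generator)
    with chainer.using_config('train', False), \
            chainer.using_config('enable_backprop', False):
        z = generator.make_hidden(64)  # assumed latent-sampling helper
        x = generator(z)               # generated image batch
    return x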
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--data_path', type=str, default="data/datasets/cifar10/train",
                        help='dataset directory path')
    parser.add_argument('--class_name', '-class', type=str, default='all_class',
                        help='class name (default: all_class(str))')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of parallel data loading processes')
    parser.add_argument('--algorithm', '-a', type=str, default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture', type=str, default="dcgan",
                        help='Network architecture')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=10000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=5,
                        help='number of discriminator update per generator update')
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2', type=float, default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument('--output_dim', type=int, default=256,
                        help='output dimension of the discriminator (for cramer GAN)')
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean",
                   "inception_std", "FID"]

    # Set up dataset
    if args.class_name == 'all_class':
        data_path = args.data_path
        one_class_flag = False
    else:
        data_path = os.path.join(args.data_path, args.class_name)
        one_class_flag = True
    train_dataset = ImageDataset(data_path, one_class_flag=one_class_flag)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, args.batchsize, n_processes=args.num_workers)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        from stdgan.updater import Updater
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        elif args.architecture == "sndcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.SNDCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma
    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
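
# The --data_path/--class_name branching above implies a per-class directory
# layout like the following (paths are hypothetical):
#
#   data/datasets/cifar10/train/            <- used with --class_name all_class
#   data/datasets/cifar10/train/airplane/   <- used with --class_name airplane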
parser.add_argument('--out', '-o', default='result_classifier',
                    help='Directory to output the result')
parser.add_argument('--num_syms', '-n', type=int, default=3,
                    help='Number of symbols to demod at a time')
parser.add_argument('--model_type', '-t', type=str, default="AlexStock",
                    help='Which Model to run (AlexStock, ComplexNN)')
parser.add_argument('--snr', '-s', type=int, default=18,
                    help='SNR to use for demodulation training')
args = parser.parse_args()

# results_output_dir = os.path.join(os.environ['KUNGLAB_SHARE_RESULTS'], args.out)
results_output_dir = args.out
if not os.path.exists(results_output_dir):
    os.makedirs(results_output_dir)
record_setting(results_output_dir)

snr = list(range(-8, 6, 2))
print(snr)
num_syms = args.num_syms
data_train = DemodSNRDataset(test=False, snr=snr, num_syms=num_syms)
data_test = DemodSNRDataset(test=True, snr=snr, num_syms=num_syms)
num_classes = np.unique(data_train.ys).shape[0]

# train model
if args.model_type == "AlexStock" or args.model_type == "AlexSmall":
    print("AlexSmall")
    model = L.Classifier(model_map[args.model_type](
        num_classes, init_weights=True, filter_height=2))
else:
    model = L.Classifier(model_map[args.model_type](
        num_classes, init_weights=True, filter_height=1))
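
# The fragment above indexes model_map, which is not shown in this section.
# Presumably it maps --model_type names to network classes, roughly:
#
#   model_map = {
#       "AlexStock": AlexStock,   # class names taken from the help text;
#       "AlexSmall": AlexSmall,   # their definitions are not part of this
#       "ComplexNN": ComplexNN,   # section
#   }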
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--algorithm', '-a', type=str, default="dcgan",
                        help='GAN algorithm')
    parser.add_argument('--architecture', type=str, default="dcgan",
                        help='Network architecture')
    parser.add_argument('--dataset', type=str, default="cifar10", help='Dataset')
    parser.add_argument('--bottom_width', type=int, default=4)
    parser.add_argument('--udvmode', type=int, default=1)
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=5000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=5000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=5,
                        help='number of discriminator update per generator update')
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.0,
                        help='beta1 in Adam optimizer')
    parser.add_argument('--adam_beta2', type=float, default=0.9,
                        help='beta2 in Adam optimizer')
    parser.add_argument('--output_dim', type=int, default=256,
                        help='output dimension of the discriminator (for cramer GAN)')
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean",
                   "inception_std", "FID", "loss_orth"]

    # Set up dataset
    if args.dataset == "cifar10":
        train_dataset = Cifar10Dataset()
    elif args.dataset == "stl10":
        train_dataset = STL10Dataset()
        args.bottom_width = 6
    else:
        raise NotImplementedError()
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    if args.algorithm == "dcgan":
        from dcgan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.DCGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "stdgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "dcgan":
            from stdgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.DCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "sndcgan":
            from stdgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.SNDCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "snresdcgan":
            from stdgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.SNResnetDiscriminator(
                bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "orthgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "orthdcgan":
            from orthgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.ORTHDCGANDiscriminator(
                bottom_width=args.bottom_width)
        elif args.architecture == "orthresdcgan":
            from orthgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.ORTHResnetDiscriminator(
                bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "uvgan":
        updater_args["n_dis"] = args.n_dis
        if args.architecture == "uvdcgan":
            from orthgan.updater import Updater
            generator = common.net.DCGANGenerator(
                bottom_width=args.bottom_width)
            discriminator = common.net.UVDCGANDiscriminator(
                args.udvmode, bottom_width=args.bottom_width)
        elif args.architecture == "uvresdcgan":
            from orthgan.updater import HingeUpdater as Updater
            generator = common.net.ResnetGenerator(
                n_hidden=256, bottom_width=args.bottom_width)
            discriminator = common.net.UVResnetDiscriminator(
                args.udvmode, bottom_width=args.bottom_width)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "dfm":
        from dfm.net import Discriminator, Denoiser
        from dfm.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
            denoiser = Denoiser()
        else:
            raise NotImplementedError()
        opts["opt_den"] = make_optimizer(denoiser, args.adam_alpha,
                                         args.adam_beta1, args.adam_beta2)
        report_keys.append("loss_den")
        models = [generator, discriminator, denoiser]
    elif args.algorithm == "minibatch_discrimination":
        from minibatch_discrimination.net import Discriminator
        from minibatch_discrimination.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
    elif args.algorithm == "began":
        from began.net import Discriminator
        from began.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator(use_bn=False)
            discriminator = Discriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("kt")
        report_keys.append("measure")
        updater_args["gamma"] = args.gamma
    elif args.algorithm == "cramer":
        from cramer.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator(
                output_dim=args.output_dim)
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "dragan":
        from dragan.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    elif args.algorithm == "wgan_gp":
        from wgan_gp.updater import Updater
        if args.architecture == "dcgan":
            generator = common.net.DCGANGenerator()
            discriminator = common.net.WGANDiscriminator()
        else:
            raise NotImplementedError()
        models = [generator, discriminator]
        report_keys.append("loss_gp")
        updater_args["n_dis"] = args.n_dis
        updater_args["lam"] = args.lam
    else:
        raise NotImplementedError()

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    if "res" not in args.architecture:
        trainer.extend(sv_generate(discriminator, args.out),
                       trigger=(args.evaluation_interval, 'iteration'),
                       priority=extension.PRIORITY_WRITER)
    IS_array = []
    FID_array = []
    trainer.extend(calc_inception(generator, IS_array, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator, FID_array, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
def main():
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--max_iter', type=int, default=100000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=10000,
                        help='Interval of evaluation')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--n_dis', type=int, default=1,
                        help='number of discriminator update per generator update')  # 5
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='hyperparameter gamma')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty')
    parser.add_argument('--adam_alpha', type=float, default=0.0002,
                        help='alpha in Adam optimizer')
    parser.add_argument('--adam_beta1', type=float, default=0.5,
                        help='beta1 in Adam optimizer')  # 0.0
    parser.add_argument('--adam_beta2', type=float, default=0.9,
                        help='beta2 in Adam optimizer')  # 0.9
    parser.add_argument('--output_dim', type=int, default=256,
                        help='output dimension of the discriminator (for cramer GAN)')
    parser.add_argument('--data-dir', type=str, default="")
    parser.add_argument('--image-npz', type=str, default="")
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--resume', type=str, default="")
    parser.add_argument('--ch', type=int, default=512)
    parser.add_argument('--snapshot-iter', type=int, default=0)
    args = parser.parse_args()
    record_setting(args.out)
    report_keys = ["loss_dis", "loss_gen"]

    # Set up dataset
    if args.image_npz != '':
        from c128dcgan.dataset import NPZColorDataset
        train_dataset = NPZColorDataset(npz=args.image_npz)
    elif args.data_dir != '':
        from c128dcgan.dataset import Color128x128Dataset
        train_dataset = Color128x128Dataset(args.data_dir)
    else:
        raise ValueError("specify either --image-npz or --data-dir")
    train_iter = chainer.iterators.SerialIterator(train_dataset, args.batchsize)

    # Setup algorithm specific networks and updaters
    models = []
    opts = {}
    updater_args = {"iterator": {'main': train_iter}, "device": args.gpu}

    # fixed algorithm
    # from c128gan import Updater
    generator = common.net.C128Generator(ch=args.ch, n_hidden=args.n_hidden)
    discriminator = common.net.SND128Discriminator(ch=args.ch)
    models = [generator, discriminator]
    from dcgan.updater import Updater

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        print("use gpu {}".format(args.gpu))
        for m in models:
            m.to_gpu()

    # Set up optimizers
    opts["opt_gen"] = make_optimizer(generator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    opts["opt_dis"] = make_optimizer(discriminator, args.adam_alpha,
                                     args.adam_beta1, args.adam_beta2)
    updater_args["optimizer"] = opts
    updater_args["models"] = models

    # Set up updater and trainer
    updater = Updater(**updater_args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=args.out)

    # Set up logging
    for m in models:
        trainer.extend(extensions.snapshot_object(
            m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=(args.snapshot_interval, 'iteration'))
    trainer.extend(extensions.LogReport(
        keys=report_keys, trigger=(args.display_interval, 'iteration')))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.display_interval, 'iteration'))
    trainer.extend(sample_generate(generator, args.out),
                   trigger=(args.evaluation_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.snapshot_iter == 0:
        snap_iter = args.max_iter // 100
    else:
        snap_iter = args.snapshot_iter
    trainer.extend(extensions.snapshot(), trigger=(snap_iter, 'iteration'))

    # resume
    if args.resume != "":
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
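
# Hypothetical resume workflow for the script above. Full trainer snapshots
# are written every snap_iter iterations with Chainer's default
# 'snapshot_iter_<N>' naming; the script name and iteration number are
# illustrative:
#
#   python train.py --image-npz faces.npz --out result
#   python train.py --image-npz faces.npz --out result \
#       --resume result/snapshot_iter_50000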