def main():
    """Train a GAN with one generator/discriminator replica per device.

    Parses the command line, loads the YAML config, builds the models and a
    deep copy of them for every non-'main' entry of the device map, sets up
    the optimizers, dataset, iterator, updater and trainer extensions,
    optionally resumes from a trainer snapshot, and runs the training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/sr.yml',
                        help='path to config file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='index of gpu to be used')
    parser.add_argument('--results_dir', type=str, default='./results',
                        help='directory to save the results to')
    parser.add_argument(
        '--inception_model_path', type=str,
        default='./datasets/inception_model/inception_score.model',
        help='path to the inception model')
    parser.add_argument(
        '--stat_file', type=str,
        default='./datasets/inception_model/fid_stats_cifar10_train.npz',
        help='path to the statistics file used for the FID evaluation')
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='number of parallel data loading processes')
    args = parser.parse_args()

    # Close the config file deterministically instead of leaking the handle.
    # NOTE(review): yaml.load without an explicit Loader can construct
    # arbitrary Python objects; only pass trusted config files.
    with open(args.config_path) as config_file:
        config = yaml_utils.Config(yaml.load(config_file))
    chainer.cuda.get_device_from_id(args.gpu).use()

    # Set up the models: the originals live on 'main', every other device
    # gets its own deep copy.
    # NOTE(review): the device map is hard-coded and independent of --gpu.
    devices = {'main': 0, 'second': 1, 'third': 2, 'fourth': 3}
    gen, dis = load_models(config)
    if 'main' not in devices:
        raise KeyError("devices must contain a 'main' key.")
    models = {'main': {"gen": gen, "dis": dis}}
    for name in devices:
        if name == 'main':
            continue
        replica_gen = copy.deepcopy(gen)
        replica_dis = copy.deepcopy(dis)
        if devices[name] >= 0:
            replica_gen.to_gpu(device=devices[name])
            replica_dis.to_gpu(device=devices[name])
        models[name] = {"gen": replica_gen, "dis": replica_dis}
    # The originals are moved only after all replicas were copied from them.
    if devices['main'] >= 0:
        gen.to_gpu(device=devices['main'])
        dis.to_gpu(device=devices['main'])

    # Print the link names of the discriminator, then of the generator.
    for link_name, _ in sorted(dis.namedlinks()):
        print(link_name)
    for link_name, _ in sorted(gen.namedlinks()):
        print(link_name)

    # Optimizer
    adam = config.adam
    opt_gen = make_optimizer(models['main']['gen'], alpha=adam['alpha'],
                             beta1=adam['beta1'], beta2=adam['beta2'])
    opt_dis = make_optimizer(models['main']['dis'], alpha=adam['alpha'],
                             beta1=adam['beta1'], beta2=adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}

    # Dataset
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)

    # Copy the configured updater arguments so the runtime objects added
    # below are not written back into the config.
    kwargs = dict(config.updater['args']) if 'args' in config.updater else {}
    kwargs.update({
        'devices': devices,
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)

    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "FID"]

    snapshot_trigger = (config.snapshot_interval, 'iteration')
    display_trigger = (config.display_interval, 'iteration')
    evaluation_trigger = (config.evaluation_interval, 'iteration')

    # Set up logging
    trainer.extend(extensions.snapshot(), trigger=snapshot_trigger)
    for m in models['main'].values():
        trainer.extend(
            extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=snapshot_trigger)
    trainer.extend(
        extensions.LogReport(keys=report_keys, trigger=display_trigger))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=display_trigger)

    # Sample generation: conditional when the generator declares classes.
    if gen.n_classes > 0:
        trainer.extend(
            sample_generate_conditional(models['main']['gen'], out,
                                        n_classes=gen.n_classes),
            trigger=evaluation_trigger,
            priority=extension.PRIORITY_WRITER)
    else:
        trainer.extend(sample_generate(models['main']['gen'], out),
                       trigger=evaluation_trigger,
                       priority=extension.PRIORITY_WRITER)
    trainer.extend(
        sample_generate_light(models['main']['gen'], out, rows=10, cols=10),
        trigger=(config.evaluation_interval // 10, 'iteration'),
        priority=extension.PRIORITY_WRITER)

    # Evaluation extensions
    trainer.extend(
        calc_inception(models['main']['gen'], n_ims=5000, splits=1,
                       dst=args.results_dir,
                       path=args.inception_model_path),
        trigger=evaluation_trigger,
        priority=extension.PRIORITY_WRITER)
    trainer.extend(
        calc_FID(models['main']['gen'], n_ims=5000, dst=args.results_dir,
                 path=args.inception_model_path, stat_file=args.stat_file),
        trigger=evaluation_trigger,
        priority=extension.PRIORITY_WRITER)
    trainer.extend(
        monitor_largest_singular_values(models['main']['dis'],
                                        dst=args.results_dir),
        trigger=evaluation_trigger,
        priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))

    # Shift Adam's alpha from its configured value to 0 over the interval
    # [iteration_decay_start, iteration] for both optimizers.
    decay_range = (config.iteration_decay_start, config.iteration)
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.), decay_range, opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.), decay_range, opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)

    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
def main():
    """Train a conditional GAN on a single GPU.

    Parses the command line, loads the YAML config, moves the generator and
    discriminator to the selected GPU, sets up the optimizers, dataset,
    iterator, updater and trainer extensions, optionally resumes from a
    trainer snapshot, and runs the training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir', type=str, default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path', type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='number of parallel data loading processes')
    args = parser.parse_args()

    # Close the config file deterministically instead of leaking the handle.
    # NOTE(review): yaml.load without an explicit Loader can construct
    # arbitrary Python objects; only pass trusted config files.
    with open(args.config_path) as config_file:
        config = yaml_utils.Config(yaml.load(config_file))
    chainer.cuda.get_device_from_id(args.gpu).use()

    # Models
    gen, dis = load_models(config)
    gen.to_gpu(device=args.gpu)
    dis.to_gpu(device=args.gpu)
    models = {"gen": gen, "dis": dis}

    # Optimizer
    adam = config.adam
    opt_gen = make_optimizer(gen, alpha=adam['alpha'], beta1=adam['beta1'],
                             beta2=adam['beta2'])
    opt_dis = make_optimizer(dis, alpha=adam['alpha'], beta1=adam['beta1'],
                             beta2=adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}

    # Dataset: the command-line data directory overrides the configured root.
    config['dataset']['args']['root'] = args.data_dir
    dataset = yaml_utils.load_dataset(config)
    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)

    # Copy the configured updater arguments so the runtime objects added
    # below are not written back into the config.
    kwargs = dict(config.updater['args']) if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)

    out = args.results_dir
    create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]

    snapshot_trigger = (config.snapshot_interval, 'iteration')
    display_trigger = (config.display_interval, 'iteration')
    evaluation_trigger = (config.evaluation_interval, 'iteration')

    # Set up logging
    trainer.extend(extensions.snapshot(), trigger=snapshot_trigger)
    for m in models.values():
        trainer.extend(
            extensions.snapshot_object(
                m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=snapshot_trigger)
    trainer.extend(
        extensions.LogReport(keys=report_keys, trigger=display_trigger))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=display_trigger)

    # Sample generation and evaluation
    trainer.extend(
        sample_generate_conditional(gen, out, n_classes=gen.n_classes),
        trigger=evaluation_trigger,
        priority=extension.PRIORITY_WRITER)
    trainer.extend(
        sample_generate_light(gen, out, rows=10, cols=10),
        trigger=(config.evaluation_interval // 10, 'iteration'),
        priority=extension.PRIORITY_WRITER)
    trainer.extend(
        calc_inception(gen, n_ims=5000, splits=1,
                       path=args.inception_model_path),
        trigger=evaluation_trigger,
        priority=extension.PRIORITY_WRITER)
    trainer.extend(
        extensions.ProgressBar(update_interval=config.progressbar_interval))

    # Shift Adam's alpha from its configured value to 0 over the interval
    # [iteration_decay_start, iteration] for both optimizers.
    decay_range = (config.iteration_decay_start, config.iteration)
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.), decay_range, opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.), decay_range, opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)

    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()
def main():
    """Train a generator/discriminator pair on ``Cifar10Dataset``.

    Parses the command line, records the settings to the output directory,
    builds a generator, a second "smooth" generator and a discriminator,
    optionally loads pretrained weights, sets up Adam optimizers, the
    updater and the trainer extensions, and runs the training loop.
    Sample generation and evaluation extensions use the smooth generator.
    """
    parser = argparse.ArgumentParser(description='Train script')
    parser.add_argument('--batchsize', '-b', type=int, default=16)
    parser.add_argument('--max_iter', '-m', type=int, default=400000)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot_interval', type=int, default=25000,
                        help='Interval of snapshot')
    parser.add_argument('--evaluation_interval', type=int, default=50000,
                        help='Interval of evaluation')
    # The help text used to duplicate "Interval of evaluation"; this option
    # is the trigger interval of the sample_generate extension below.
    parser.add_argument('--out_image_interval', type=int, default=12500,
                        help='Interval of sample image output')
    parser.add_argument('--stage_interval', type=int, default=400000,
                        help='Interval of stage progress')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument(
        '--n_dis', type=int, default=1,
        help='number of discriminator update per generator update')
    parser.add_argument('--lam', type=float, default=10,
                        help='gradient penalty')
    parser.add_argument('--gamma', type=float, default=750,
                        help='gradient penalty')
    parser.add_argument('--pooling_comp', type=float, default=1.0,
                        help='compensation')
    parser.add_argument('--pretrained_generator', type=str, default="")
    parser.add_argument('--pretrained_discriminator', type=str, default="")
    parser.add_argument('--initial_stage', type=float, default=0.0)
    parser.add_argument('--generator_smoothing', type=float, default=0.999)
    args = parser.parse_args()
    record_setting(args.out)

    report_keys = [
        "stage", "loss_dis", "loss_gp", "loss_gen", "g", "inception_mean",
        "inception_std", "FID"
    ]
    max_iter = args.max_iter
    use_gpu = args.gpu >= 0

    if use_gpu:
        chainer.cuda.get_device_from_id(args.gpu).use()

    generator = Generator()
    generator_smooth = Generator()
    discriminator = Discriminator(pooling_comp=args.pooling_comp)

    # select GPU
    if use_gpu:
        generator.to_gpu()
        generator_smooth.to_gpu()
        discriminator.to_gpu()
        print("use gpu {}".format(args.gpu))

    # Optional warm start from pretrained weights.
    if args.pretrained_generator:
        chainer.serializers.load_npz(args.pretrained_generator, generator)
    if args.pretrained_discriminator:
        chainer.serializers.load_npz(args.pretrained_discriminator,
                                     discriminator)

    # NOTE(review): presumably copies the generator's parameters into the
    # smooth generator so both start identical - confirm against copy_param.
    copy_param(generator_smooth, generator)

    # Setup an optimizer
    def make_optimizer(model, alpha=0.001, beta1=0.0, beta2=0.99):
        """Return an Adam optimizer that has been set up on ``model``."""
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1,
                                            beta2=beta2)
        optimizer.setup(model)
        return optimizer

    opt_gen = make_optimizer(generator)
    opt_dis = make_optimizer(discriminator)

    train_dataset = Cifar10Dataset()
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  args.batchsize)

    # Set up a trainer
    updater = Updater(
        models=(generator, discriminator, generator_smooth),
        iterator={'main': train_iter},
        optimizer={
            'opt_gen': opt_gen,
            'opt_dis': opt_dis
        },
        device=args.gpu,
        n_dis=args.n_dis,
        lam=args.lam,
        gamma=args.gamma,
        smoothing=args.generator_smoothing,
        initial_stage=args.initial_stage,
        stage_interval=args.stage_interval)
    trainer = training.Trainer(updater, (max_iter, 'iteration'), out=args.out)

    snapshot_trigger = (args.snapshot_interval, 'iteration')
    display_trigger = (args.display_interval, 'iteration')
    evaluation_trigger = (args.evaluation_interval, 'iteration')

    # Periodic snapshots of the three networks.
    trainer.extend(
        extensions.snapshot_object(
            generator, 'generator_{.updater.iteration}.npz'),
        trigger=snapshot_trigger)
    trainer.extend(
        extensions.snapshot_object(
            generator_smooth, 'generator_smooth_{.updater.iteration}.npz'),
        trigger=snapshot_trigger)
    trainer.extend(
        extensions.snapshot_object(
            discriminator, 'discriminator_{.updater.iteration}.npz'),
        trigger=snapshot_trigger)

    # Logging
    trainer.extend(
        extensions.LogReport(keys=report_keys, trigger=display_trigger))
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=display_trigger)

    # Sample generation and evaluation on the smooth generator.
    trainer.extend(sample_generate(generator_smooth, args.out),
                   trigger=(args.out_image_interval, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(sample_generate_light(generator_smooth, args.out),
                   trigger=(args.evaluation_interval // 10, 'iteration'),
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_inception(generator_smooth),
                   trigger=evaluation_trigger,
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(calc_FID(generator_smooth),
                   trigger=evaluation_trigger,
                   priority=extension.PRIORITY_WRITER)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Run the training
    trainer.run()
def main():
    """Train a GAN across several processes with ChainerMN.

    Parses the command line, loads the YAML config, creates the ChainerMN
    communicator and selects the GPU given by the intra-node rank, builds the
    models and optimizers, scatters the dataset from rank 0, sets up the
    updater and trainer (reporting, snapshot and evaluation extensions are
    registered on rank 0 only), optionally resumes from a trainer snapshot,
    and runs the training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml',
                        help='path to config file')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir', type=str, default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path', type=str,
                        default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='number of parallel data loading processes')
    parser.add_argument('--communicator', type=str, default='hierarchical',
                        help='Type of communicator')
    args = parser.parse_args()

    # Close the config file deterministically instead of leaking the handle.
    # NOTE(review): yaml.load without an explicit Loader can construct
    # arbitrary Python objects; only pass trusted config files.
    with open(args.config_path) as config_file:
        config = yaml_utils.Config(yaml.load(config_file))

    # MultiprocessIterator is used below, so the start method is switched to
    # 'forkserver'. ChainerMN's guidance is to do this, and to launch a dummy
    # process so the fork server actually starts, *before* the communicator
    # is created.
    multiprocessing.set_start_method('forkserver')
    dummy_process = multiprocessing.Process()
    dummy_process.start()
    dummy_process.join()

    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    chainer.cuda.get_device_from_id(device).use()
    print("init")
    if comm.rank == 0:
        print('==========================================')
        print('Using {} communicator'.format(args.communicator))
        print('==========================================')

    # Model
    gen, dis = load_models(config)
    gen.to_gpu()
    dis.to_gpu()
    models = {"gen": gen, "dis": dis}

    # Optimizer
    adam = config.adam
    opt_gen = make_optimizer(gen, comm, alpha=adam['alpha'],
                             beta1=adam['beta1'], beta2=adam['beta2'])
    opt_dis = make_optimizer(dis, comm, alpha=adam['alpha'],
                             beta1=adam['beta1'], beta2=adam['beta2'])
    opts = {"opt_gen": opt_gen, "opt_dis": opt_dis}

    # Dataset
    # Cifar10 dataset handler does not take "root" as argument.
    if config['dataset']['dataset_name'] != 'CIFAR10Dataset':
        config['dataset']['args']['root'] = args.data_dir
    if comm.rank == 0:
        dataset = yaml_utils.load_dataset(config)
    else:
        # Dummy, for adding path to the dataset module
        _ = yaml_utils.load_dataset(config)
        dataset = None
    dataset = chainermn.scatter_dataset(dataset, comm)

    # Iterator
    iterator = chainer.iterators.MultiprocessIterator(
        dataset, config.batchsize, n_processes=args.loaderjob)

    # Copy the configured updater arguments so the runtime objects added
    # below are not written back into the config.
    kwargs = dict(config.updater['args']) if 'args' in config.updater else {}
    kwargs.update({
        'models': models,
        'iterator': iterator,
        'optimizer': opts,
        'device': device,
    })
    updater = yaml_utils.load_updater_class(config)
    updater = updater(**kwargs)

    out = args.results_dir
    if comm.rank == 0:
        create_result_dir(out, args.config_path, config)
    trainer = training.Trainer(updater, (config.iteration, 'iteration'),
                               out=out)
    report_keys = ["loss_dis", "loss_gen", "inception_mean", "inception_std"]

    if comm.rank == 0:
        snapshot_trigger = (config.snapshot_interval, 'iteration')
        display_trigger = (config.display_interval, 'iteration')
        evaluation_trigger = (config.evaluation_interval, 'iteration')

        # Set up logging
        trainer.extend(extensions.snapshot(), trigger=snapshot_trigger)
        for m in models.values():
            trainer.extend(
                extensions.snapshot_object(
                    m, m.__class__.__name__ + '_{.updater.iteration}.npz'),
                trigger=snapshot_trigger)
        trainer.extend(
            extensions.LogReport(keys=report_keys, trigger=display_trigger))
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=display_trigger)

        # Sample generation: conditional when the generator declares classes.
        if gen.n_classes > 0:
            trainer.extend(
                sample_generate_conditional(gen, out,
                                            n_classes=gen.n_classes),
                trigger=evaluation_trigger,
                priority=extension.PRIORITY_WRITER)
        else:
            trainer.extend(sample_generate(gen, out),
                           trigger=evaluation_trigger,
                           priority=extension.PRIORITY_WRITER)
        trainer.extend(
            sample_generate_light(gen, out, rows=10, cols=10),
            trigger=(config.evaluation_interval // 10, 'iteration'),
            priority=extension.PRIORITY_WRITER)
        trainer.extend(
            calc_inception(gen, n_ims=5000, splits=1,
                           path=args.inception_model_path),
            trigger=evaluation_trigger,
            priority=extension.PRIORITY_WRITER)
        trainer.extend(
            extensions.ProgressBar(
                update_interval=config.progressbar_interval))

    # Registered on every rank: shift Adam's alpha from its configured value
    # to 0 over the interval [iteration_decay_start, iteration].
    decay_range = (config.iteration_decay_start, config.iteration)
    ext_opt_gen = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.), decay_range, opt_gen)
    ext_opt_dis = extensions.LinearShift(
        'alpha', (config.adam['alpha'], 0.), decay_range, opt_dis)
    trainer.extend(ext_opt_gen)
    trainer.extend(ext_opt_dis)

    if args.snapshot:
        print("Resume training with snapshot:{}".format(args.snapshot))
        chainer.serializers.load_npz(args.snapshot, trainer)

    # Run the training
    print("start training")
    trainer.run()