def main():
    """Train a WGAN-GP on images read from ``--input_folder``.

    The generator and discriminator are constructed first, then the default
    command-line options are parsed. The parsed options are forwarded as
    keyword arguments to both ``WGAN_GP`` and ``Trainer``.
    """
    gen_net = make_generator()
    disc_net = make_discriminator()

    cli = parser_with_default_args().parse_args()
    cli_kwargs = vars(cli)

    # (128, ) and (128, 64) are positional FolderDataset arguments; their
    # meaning is defined by FolderDataset, which is not visible here.
    folder_data = FolderDataset(cli.input_folder, cli.batch_size,
                                (128, ), (128, 64))

    model = WGAN_GP(gen_net, disc_net, **cli_kwargs)
    runner = Trainer(folder_data, model, **cli_kwargs)
    runner.train()
def main():
    """Train a conditional GAN (CGAN) on photo/sketch data.

    Adds sketch-specific options on top of ``parser_with_default_args`` and
    parses the command line. Then it:

    * builds a ``SketchDataset``;
    * creates a generator and discriminator sized by ``--image_size`` and
      ``--number_of_classes``;
    * prints both model summaries;
    * trains a ``CGAN`` through ``Trainer``.

    All parsed options are also forwarded to ``CGAN`` and ``Trainer`` as
    keyword arguments.
    """
    parser = parser_with_default_args()
    parser.add_argument("--images_folder",
                        default="data/photo/tx_000100000000",
                        help='Folder with photos')
    parser.add_argument("--sketch_folder",
                        default="data/sketch/tx_000000000000",
                        help='Folder with sketches')
    parser.add_argument(
        "--invalid_files",
        default=[
            'data/info/invalid-ambiguous.txt', 'data/info/invalid-context.txt',
            'data/info/invalid-error.txt', 'data/info/invalid-pose.txt'
        ],
        help='List of files with invalid sketches, comma separated',
        type=lambda x: x.split(','))
    parser.add_argument("--test_set",
                        default='data/info/testset.txt',
                        help='File with test set')
    # Without an explicit type, a command-line value stays a raw string such
    # as "64,64" while the default is a tuple. Parse it into a tuple of ints
    # so both paths hand the same kind of value downstream. argparse only
    # applies ``type`` to string defaults, so the tuple default is untouched.
    parser.add_argument("--image_size",
                        default=(64, 64),
                        type=lambda x: tuple(int(v) for v in x.split(',')),
                        help='Size of the images, comma separated (e.g. 64,64)')
    parser.add_argument(
        "--number_of_classes",
        default=2,
        type=int,  # without this a command-line value would arrive as a str
        help='Number of classes to train on, usefull for debugging')
    # --cache_dir is not read directly in this function; it only reaches
    # CGAN/Trainer through **vars(args).
    parser.add_argument("--cache_dir",
                        default='tmp',
                        help='Store distance transforms to this folder.')
    args = parser.parse_args()

    dataset = SketchDataset(images_folder=args.images_folder,
                            sketch_folder=args.sketch_folder,
                            batch_size=args.batch_size,
                            invalid_images_files=args.invalid_files,
                            test_set=args.test_set,
                            number_of_classes=args.number_of_classes,
                            image_size=args.image_size)

    generator = make_generator(image_size=args.image_size,
                               number_of_classes=args.number_of_classes)
    discriminator = make_discriminator(
        image_size=args.image_size, number_of_classes=args.number_of_classes)
    generator.summary()
    discriminator.summary()

    gan = CGAN(generator=generator, discriminator=discriminator, **vars(args))
    trainer = Trainer(dataset, gan, **vars(args))
    trainer.train()
def main():
    """Train a GDAN on MNIST, scoring samples at each checkpoint.

    Builds the generator and discriminator and prints their summaries. It
    then parses the default command-line options and trains a ``GDAN`` via
    ``Trainer``. When a checkpoint is stored, ``compute_scores`` is invoked
    with the generator, the dataset, a 28x28x1 image shape, and inception
    scoring disabled.
    """
    gen_net = make_generator()
    disc_net = make_discriminator()
    for net in (gen_net, disc_net):
        net.summary()

    cli_kwargs = vars(parser_with_default_args().parse_args())
    mnist_data = MNISTDataset(cli_kwargs['batch_size'])

    model = GDAN(generator=gen_net, discriminator=disc_net, **cli_kwargs)

    # Bind everything except what Trainer supplies at checkpoint time.
    checkpoint_hook = partial(compute_scores,
                              generator=gen_net,
                              dataset=mnist_data,
                              image_shape=(28, 28, 1),
                              compute_inception=False)

    runner = Trainer(mnist_data, model,
                     at_store_checkpoint_hook=checkpoint_hook, **cli_kwargs)
    runner.train()
def main():
    """Train a GAN on the style (DeviantArt) image dataset.

    Builds the generator and discriminator, then adds three dataset options
    to the default parser: input directory, cache file, and content image.
    It then trains a ``GAN`` on a ``StylesDataset`` through ``Trainer``,
    forwarding all parsed options as keyword arguments.
    """
    gen_net = make_generator()
    disc_net = make_discriminator()

    parser = parser_with_default_args()
    extra_options = (
        ("--input_dir", 'dataset/devian_art'),
        ("--cache_file_name", 'output/devian_art.npy'),
        ("--content_image", 'sup-mat/cornell_cropped.jpg'),
    )
    for flag, fallback in extra_options:
        parser.add_argument(flag, default=fallback)
    cli = parser.parse_args()
    cli_kwargs = vars(cli)

    styles_data = StylesDataset(cli.batch_size,
                                cli.input_dir,
                                cache_file_name=cli.cache_file_name,
                                content_image=cli.content_image)
    model = GAN(gen_net, disc_net, **cli_kwargs)
    runner = Trainer(styles_data, model, **cli_kwargs)
    runner.train()
def main():
    """Configure and launch a GAN experiment from command-line options.

    Steps, in order:
      1. Register experiment/architecture options on top of
         ``parser_with_default_args`` and parse the command line.
      2. Load the selected dataset with ``get_dataset``. It is requested as
         supervised whenever ``--gan_type`` is not None.
      3. Create ``output/<name>_<phase>_<time()>``. Use it as both output and
         checkpoint directory, and dump the parsed options to
         ``config.json`` inside it. The dump happens before ``image_shape``
         and ``fid_cache_file`` are added to ``args``.
      4. Look up the image shape for the dataset, set the FID cache path,
         build generator/discriminator parameters from ``args``, and hand
         everything to ``compile_and_run``.
    """

    def _none_or_str(value):
        # argparse passes raw strings, so without this "--gan_type None"
        # would be rejected against the None entry in ``choices``.
        return None if value == 'None' else value

    parser = parser_with_default_args()
    # NOTE(review): several help texts below are built from adjacent string
    # literals without separators, so --help shows e.g. "images.For test".
    parser.add_argument(
        "--name",
        default="gan",
        help="Name of the experiment (it will create corresponding folder)")
    parser.add_argument(
        "--phase",
        choices=['train', 'test'],
        default='train',
        help=
        "Train or test, test only compute scores and generate grid of images."
        "For test generator checkpoint should be given.")
    parser.add_argument("--dataset",
                        default='cifar10',
                        choices=[
                            'mnist', 'cifar10', 'cifar100', 'fashion-mnist',
                            'stl10', 'imagenet', 'tiny-imagenet'
                        ],
                        help='Dataset to train on')
    parser.add_argument("--arch",
                        default='res',
                        choices=['res', 'dcgan'],
                        help="Gan architecture resnet or dcgan.")

    # Optimizer settings.
    parser.add_argument("--generator_lr",
                        default=2e-4,
                        type=float,
                        help="Learning rate")
    parser.add_argument("--discriminator_lr",
                        default=2e-4,
                        type=float,
                        help="Learning rate")
    parser.add_argument("--beta1",
                        default=0,
                        type=float,
                        help='Adam parameter')
    parser.add_argument("--beta2",
                        default=0.9,
                        type=float,
                        help='Adam parameter')
    parser.add_argument(
        "--lr_decay_schedule",
        default=None,
        help='Learnign rate decay schedule:'
        'None - no decay.'
        'linear - linear decay to zero.'
        'half-linear - linear decay to 0.5'
        'linear-end - constant until 0.9, then linear decay to 0. '
        'dropat30 - drop lr 10 times at 30 epoch (any number insdead of 30 allowed).'
    )

    # Spectral normalization settings.
    parser.add_argument("--generator_spectral",
                        default=0,
                        type=int,
                        help='Use spectral norm in generator.')
    parser.add_argument("--discriminator_spectral",
                        default=0,
                        type=int,
                        help='Use spectral norm in discriminator.')
    parser.add_argument("--fully_diff_spectral",
                        default=0,
                        type=int,
                        help='Fully difirentiable spectral normalization.')
    parser.add_argument("--spectral_iterations",
                        default=1,
                        type=int,
                        help='Number of iteration per spectral update.')
    parser.add_argument("--conv_singular",
                        default=0,
                        type=int,
                        help='Use convolutional spectral normalization.')

    parser.add_argument("--gan_type",
                        default=None,
                        choices=[None, 'AC_GAN', 'PROJECTIVE'],
                        type=_none_or_str,
                        help='Type of gan to use. None for unsuperwised.')
    parser.add_argument("--filters_emb",
                        default=10,
                        type=int,
                        help='Number of inner filters in factorized conv.')

    # Generator architecture.
    parser.add_argument(
        "--generator_block_norm",
        default='b',
        choices=['n', 'b', 'd', 'dr'],
        help=
        'Normalization in generator block. b - batch, d - whitening, n - none, '
        'dr - whitening with renornaliazation.')
    parser.add_argument(
        "--generator_block_after_norm",
        default='ucs',
        choices=[
            'ccs', 'fconv', 'ucs', 'uccs', 'ufconv', 'cconv', 'uconv',
            'ucconv', 'ccsuconv', 'n'
        ],
        help=
        "Layer after block normalization. ccs - conditional shift and scale."
        "ucs - uncoditional shift and scale. ucconv - condcoloring. ufconv - condcoloring + sa."
        "n - None.")
    parser.add_argument(
        "--generator_last_norm",
        default='b',
        choices=['n', 'b', 'd', 'dr'],
        help=
        'Normalization in generator block. b - batch, d - whitening, n - none, '
        'dr - whitening with renornaliazation.')
    parser.add_argument(
        "--generator_last_after_norm",
        default='ucs',
        choices=[
            'ccs', 'ucs', 'uccs', 'ufconv', 'cconv', 'uconv', 'ucconv',
            'ccsuconv', 'n'
        ],
        help=
        "Layer after block normalization. ccs - conditional shift and scale."
        "ucs - uncoditional shift and scale. ucconv - condcoloring. ufconv - condcoloring + sa."
        "n - None.")
    parser.add_argument(
        "--generator_batch_multiple",
        default=2,
        type=int,
        help="Size of the generator batch, multiple of batch_size.")
    parser.add_argument("--generator_concat_cls",
                        default=0,
                        type=int,
                        help='Concat labels to noise in generator.')
    parser.add_argument("--generator_filters",
                        default=128,
                        type=int,
                        help='Base number of filters in generator block.')

    # Discriminator architecture.
    parser.add_argument(
        "--discriminator_norm",
        default='n',
        choices=['n', 'b', 'd', 'dr'],
        help=
        'Normalization in disciminator block. b - batch, d - whitening, n - none, '
        'dr - whitening with renornaliazation.')
    parser.add_argument(
        "--discriminator_after_norm",
        default='n',
        choices=[
            'ccs', 'fconv', 'ucs', 'uccs', 'ufconv', 'cconv', 'uconv',
            'ucconv', 'ccsuconv', 'n'
        ],
        help=
        "Layer after block normalization. ccs - conditional shift and scale."
        "ucs - uncoditional shift and scale. ucconv - condcoloring. ufconv - condcoloring + sa."
        "n - None.")
    parser.add_argument("--discriminator_filters",
                        default=128,
                        type=int,
                        help='Base number of filters in discriminator block.')
    parser.add_argument("--discriminator_dropout",
                        type=float,
                        default=0,
                        help="Use dropout in discriminator.")
    parser.add_argument("--shred_disc_batch",
                        type=int,
                        default=0,
                        help='Shred batch in discriminator to save memory')
    parser.add_argument("--sum_pool",
                        default=1,
                        type=int,
                        help='Use sum or average pooling in discriminator.')

    # Evaluation settings.
    parser.add_argument("--samples_inception",
                        default=50000,
                        type=int,
                        help='Samples for IS score, 0 - no compute inception')
    parser.add_argument("--samples_fid",
                        default=10000,
                        type=int,
                        help="Samples for FID score, 0 - no compute FID")

    args = parser.parse_args()

    dataset = get_dataset(dataset=args.dataset,
                          batch_size=args.batch_size,
                          supervised=args.gan_type is not None)

    # Each run gets its own directory, suffixed with the current time().
    args.output_dir = "output/%s_%s_%s" % (args.name, args.phase, time())
    # Single-argument print(...) works on both Python 2 and Python 3; the
    # former print statement was a SyntaxError on Python 3.
    print(args.output_dir)
    args.checkpoints_dir = args.output_dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(os.path.join(args.output_dir, 'config.json'), 'w') as outfile:
        json.dump(vars(args), outfile, indent=4)

    image_shape_dict = {
        'mnist': (28, 28, 1),
        'fashion-mnist': (28, 28, 1),
        'cifar10': (32, 32, 3),
        'cifar100': (32, 32, 3),
        'stl10': (48, 48, 3),
        'imagenet': (128, 128, 3),
        'tiny-imagenet': (64, 64, 3)
    }
    # --dataset is restricted by ``choices`` to exactly these keys.
    args.image_shape = image_shape_dict[args.dataset]
    print("Image shape %s x %s x %s" % args.image_shape)
    args.fid_cache_file = "output/%s_fid.npz" % args.dataset

    discriminator_params = get_discriminator_params(args)
    generator_params = get_generator_params(args)

    # NOTE(review): the dataset name is removed from args before
    # compile_and_run; presumably to avoid clashing with the positional
    # dataset object downstream -- confirm against compile_and_run.
    del args.dataset
    compile_and_run(dataset, args, generator_params, discriminator_params)