예제 #1
0
파일: baseline.py 프로젝트: zlpsls/pose-gan
def main():
    """Train the pose-gan baseline WGAN-GP on images read from a folder.

    The generator and discriminator are built first, then the shared
    command-line options are parsed.  The parsed options are forwarded as
    keyword arguments to both ``WGAN_GP`` and ``Trainer``.
    """
    gen_model = make_generator()
    disc_model = make_discriminator()

    cli_args = parser_with_default_args().parse_args()
    options = vars(cli_args)

    # NOTE(review): the two tuples are forwarded positionally to FolderDataset;
    # their meaning is defined by FolderDataset's signature — confirm there.
    data = FolderDataset(cli_args.input_folder, cli_args.batch_size,
                         (128, ), (128, 64))
    model = WGAN_GP(gen_model, disc_model, **options)
    Trainer(data, model, **options).train()
예제 #2
0
def main():
    """Train a conditional GAN (CGAN) on the photo/sketch dataset.

    Extends ``parser_with_default_args`` with dataset options, builds a
    ``SketchDataset`` and a generator/discriminator pair sized by
    ``--image_size`` and ``--number_of_classes``, prints both model summaries
    and runs the ``Trainer``.  All parsed options are also forwarded as
    keyword arguments to ``CGAN`` and ``Trainer``.

    Command-line formats:
      --invalid_files   comma-separated list of files.
      --image_size      comma-separated integers, e.g. ``64,64``.
      --number_of_classes  integer.
    """

    def parse_image_size(text):
        """Turn a comma-separated size such as ``"64,64"`` into ``(64, 64)``."""
        return tuple(int(part) for part in text.split(','))

    parser = parser_with_default_args()
    parser.add_argument("--images_folder",
                        default="data/photo/tx_000100000000",
                        help='Folder with photos')
    parser.add_argument("--sketch_folder",
                        default="data/sketch/tx_000000000000",
                        help='Folder with sketches')
    parser.add_argument(
        "--invalid_files",
        default=[
            'data/info/invalid-ambiguous.txt', 'data/info/invalid-context.txt',
            'data/info/invalid-error.txt', 'data/info/invalid-pose.txt'
        ],
        help='List of files with invalid sketches, comma separated',
        type=lambda x: x.split(','))
    parser.add_argument("--test_set",
                        default='data/info/testset.txt',
                        help='File with test set')
    # argparse applies ``type`` only to command-line values (and string
    # defaults). Without it an override such as "64,64" reached the dataset
    # and model builders as a raw string. The tuple default is left as is.
    parser.add_argument("--image_size",
                        default=(64, 64),
                        type=parse_image_size,
                        help='Size of the images, comma separated (e.g. 64,64)')
    # Same issue: without ``type=int`` a CLI override stayed a string.
    parser.add_argument(
        "--number_of_classes",
        default=2,
        type=int,
        help='Number of classes to train on, usefull for debugging')
    parser.add_argument("--cache_dir",
                        default='tmp',
                        help='Store distance transforms to this folder.')

    args = parser.parse_args()

    dataset = SketchDataset(images_folder=args.images_folder,
                            sketch_folder=args.sketch_folder,
                            batch_size=args.batch_size,
                            invalid_images_files=args.invalid_files,
                            test_set=args.test_set,
                            number_of_classes=args.number_of_classes,
                            image_size=args.image_size)

    generator = make_generator(image_size=args.image_size,
                               number_of_classes=args.number_of_classes)
    discriminator = make_discriminator(
        image_size=args.image_size, number_of_classes=args.number_of_classes)
    generator.summary()
    discriminator.summary()

    gan = CGAN(generator=generator, discriminator=discriminator, **vars(args))
    trainer = Trainer(dataset, gan, **vars(args))

    trainer.train()
예제 #3
0
def main():
    """Train a GDAN on MNIST, running ``compute_scores`` at checkpoint time.

    Both models are built and summarized before the shared command-line
    options are parsed.  The parsed options are forwarded as keyword
    arguments to ``GDAN`` and ``Trainer``.
    """
    gen_model = make_generator()
    disc_model = make_discriminator()
    for net in (gen_model, disc_model):
        net.summary()

    cli_args = parser_with_default_args().parse_args()
    options = vars(cli_args)
    mnist = MNISTDataset(cli_args.batch_size)
    model = GDAN(generator=gen_model, discriminator=disc_model, **options)

    # Hook invoked by the Trainer when it stores a checkpoint; samples are
    # scored with image_shape (28, 28, 1) and compute_inception disabled.
    checkpoint_hook = partial(compute_scores,
                              generator=gen_model,
                              dataset=mnist,
                              image_shape=(28, 28, 1),
                              compute_inception=False)
    Trainer(mnist, model, at_store_checkpoint_hook=checkpoint_hook,
            **options).train()
예제 #4
0
def main():
    """Train a GAN on the styles dataset (by default ``dataset/devian_art``).

    Adds dataset-specific options to ``parser_with_default_args``, builds a
    ``StylesDataset`` and runs the ``Trainer``.  All parsed options are also
    forwarded as keyword arguments to ``GAN`` and ``Trainer``.
    """
    gen_model = make_generator()
    disc_model = make_discriminator()

    parser = parser_with_default_args()
    # (flag, default) pairs registered in this exact order.
    extra_options = (
        ("--input_dir", 'dataset/devian_art'),
        ("--cache_file_name", 'output/devian_art.npy'),
        ("--content_image", 'sup-mat/cornell_cropped.jpg'),
    )
    for flag, default_value in extra_options:
        parser.add_argument(flag, default=default_value)

    cli_args = parser.parse_args()
    options = vars(cli_args)

    styles = StylesDataset(cli_args.batch_size,
                           cli_args.input_dir,
                           cache_file_name=cli_args.cache_file_name,
                           content_image=cli_args.content_image)
    model = GAN(gen_model, disc_model, **options)
    Trainer(styles, model, **options).train()
예제 #5
0
파일: run.py 프로젝트: zzy950117/wc-gan
def main():
    """Parse wc-gan experiment options, prepare the output folder and run.

    Steps:
      1. Extend ``parser_with_default_args`` with the experiment options and
         parse them.
      2. Load the dataset; it is supervised whenever ``--gan_type`` is set.
      3. Create ``output/<name>_<phase>_<time()>`` (also used as the
         checkpoint folder) and dump the parsed options to ``config.json``.
      4. Derive the image shape and FID cache file from the dataset name,
         build the generator/discriminator parameters and delegate to
         ``compile_and_run``.
    """
    # Shared help text for the three "layer after normalization" options.
    after_norm_help = ("Layer after block normalization. "
                       "ccs - conditional shift and scale. "
                       "ucs - unconditional shift and scale. "
                       "ucconv - condcoloring. ufconv - condcoloring + sa. "
                       "n - None.")

    parser = parser_with_default_args()
    parser.add_argument(
        "--name",
        default="gan",
        help="Name of the experiment (it will create corresponding folder)")
    parser.add_argument(
        "--phase",
        choices=['train', 'test'],
        default='train',
        help=
        "Train or test, test only computes scores and generates a grid of images. "
        "For test a generator checkpoint should be given.")

    parser.add_argument("--dataset",
                        default='cifar10',
                        choices=[
                            'mnist', 'cifar10', 'cifar100', 'fashion-mnist',
                            'stl10', 'imagenet', 'tiny-imagenet'
                        ],
                        help='Dataset to train on')
    parser.add_argument("--arch",
                        default='res',
                        choices=['res', 'dcgan'],
                        help="Gan architecture resnet or dcgan.")

    parser.add_argument("--generator_lr",
                        default=2e-4,
                        type=float,
                        help="Learning rate of the generator")
    parser.add_argument("--discriminator_lr",
                        default=2e-4,
                        type=float,
                        help="Learning rate of the discriminator")

    parser.add_argument("--beta1",
                        default=0,
                        type=float,
                        help='Adam parameter')
    parser.add_argument("--beta2",
                        default=0.9,
                        type=float,
                        help='Adam parameter')
    # Implicitly concatenated literals: each piece ends with a space so the
    # rendered help does not run entries together.
    parser.add_argument(
        "--lr_decay_schedule",
        default=None,
        help='Learning rate decay schedule: '
        'None - no decay. '
        'linear - linear decay to zero. '
        'half-linear - linear decay to 0.5. '
        'linear-end - constant until 0.9, then linear decay to 0. '
        'dropat30 - drop lr 10 times at 30 epoch (any number instead of 30 allowed).'
    )

    parser.add_argument("--generator_spectral",
                        default=0,
                        type=int,
                        help='Use spectral norm in generator.')
    parser.add_argument("--discriminator_spectral",
                        default=0,
                        type=int,
                        help='Use spectral norm in discriminator.')

    parser.add_argument("--fully_diff_spectral",
                        default=0,
                        type=int,
                        help='Fully differentiable spectral normalization.')
    parser.add_argument("--spectral_iterations",
                        default=1,
                        type=int,
                        help='Number of iteration per spectral update.')
    parser.add_argument("--conv_singular",
                        default=0,
                        type=int,
                        help='Use convolutional spectral normalization.')

    # None cannot be typed on the command line (it is not the string "None"),
    # so unsupervised training is selected by leaving the option unset.
    parser.add_argument("--gan_type",
                        default=None,
                        choices=[None, 'AC_GAN', 'PROJECTIVE'],
                        help='Type of gan to use. Leave unset (None) for unsupervised.')

    parser.add_argument("--filters_emb",
                        default=10,
                        type=int,
                        help='Number of inner filters in factorized conv.')

    parser.add_argument(
        "--generator_block_norm",
        default='b',
        choices=['n', 'b', 'd', 'dr'],
        help=
        'Normalization in generator block. b - batch, d - whitening, n - none, '
        'dr - whitening with renormalization.')
    parser.add_argument(
        "--generator_block_after_norm",
        default='ucs',
        choices=[
            'ccs', 'fconv', 'ucs', 'uccs', 'ufconv', 'cconv', 'uconv',
            'ucconv', 'ccsuconv', 'n'
        ],
        help=after_norm_help)
    parser.add_argument(
        "--generator_last_norm",
        default='b',
        choices=['n', 'b', 'd', 'dr'],
        help=
        'Normalization in generator block. b - batch, d - whitening, n - none, '
        'dr - whitening with renormalization.')
    parser.add_argument(
        "--generator_last_after_norm",
        default='ucs',
        choices=[
            'ccs', 'ucs', 'uccs', 'ufconv', 'cconv', 'uconv', 'ucconv',
            'ccsuconv', 'n'
        ],
        help=after_norm_help)
    parser.add_argument(
        "--generator_batch_multiple",
        default=2,
        type=int,
        help="Size of the generator batch, multiple of batch_size.")
    parser.add_argument("--generator_concat_cls",
                        default=0,
                        type=int,
                        help='Concat labels to noise in generator.')
    parser.add_argument("--generator_filters",
                        default=128,
                        type=int,
                        help='Base number of filters in generator block.')

    parser.add_argument(
        "--discriminator_norm",
        default='n',
        choices=['n', 'b', 'd', 'dr'],
        help=
        'Normalization in discriminator block. b - batch, d - whitening, n - none, '
        'dr - whitening with renormalization.')
    parser.add_argument(
        "--discriminator_after_norm",
        default='n',
        choices=[
            'ccs', 'fconv', 'ucs', 'uccs', 'ufconv', 'cconv', 'uconv',
            'ucconv', 'ccsuconv', 'n'
        ],
        help=after_norm_help)
    parser.add_argument("--discriminator_filters",
                        default=128,
                        type=int,
                        help='Base number of filters in discriminator block.')
    parser.add_argument("--discriminator_dropout",
                        type=float,
                        default=0,
                        help="Use dropout in discriminator.")
    parser.add_argument("--shred_disc_batch",
                        type=int,
                        default=0,
                        help='Shred batch in discriminator to save memory')

    parser.add_argument("--sum_pool",
                        default=1,
                        type=int,
                        help='Use sum or average pooling in discriminator.')

    parser.add_argument("--samples_inception",
                        default=50000,
                        type=int,
                        help='Samples for IS score, 0 - no compute inception')
    parser.add_argument("--samples_fid",
                        default=10000,
                        type=int,
                        help="Samples for FID score, 0 - no compute FID")

    args = parser.parse_args()

    dataset = get_dataset(dataset=args.dataset,
                          batch_size=args.batch_size,
                          supervised=args.gan_type is not None)

    args.output_dir = "output/%s_%s_%s" % (args.name, args.phase, time())
    # Function-call form: prints the same on Python 2, and unlike the old
    # ``print x`` statement it is valid syntax on Python 3.
    print(args.output_dir)
    args.checkpoints_dir = args.output_dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # NOTE: config.json is written before image_shape / fid_cache_file are
    # added below, so those two derived fields are not part of it.
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as outfile:
        json.dump(vars(args), outfile, indent=4)

    # Every --dataset choice has an entry here, so the lookup cannot miss.
    image_shape_dict = {
        'mnist': (28, 28, 1),
        'fashion-mnist': (28, 28, 1),
        'cifar10': (32, 32, 3),
        'cifar100': (32, 32, 3),
        'stl10': (48, 48, 3),
        'imagenet': (128, 128, 3),
        'tiny-imagenet': (64, 64, 3)
    }

    args.image_shape = image_shape_dict[args.dataset]
    print("Image shape %s x %s x %s" % args.image_shape)
    args.fid_cache_file = "output/%s_fid.npz" % args.dataset

    discriminator_params = get_discriminator_params(args)
    generator_params = get_generator_params(args)

    # Drop the dataset name from args now that it has been consumed above.
    del args.dataset

    compile_and_run(dataset, args, generator_params, discriminator_params)