Esempio n. 1
0
def compile_and_run(dataset, args, generator_params, discriminator_params):
    """Build a generator/discriminator pair, then either train a GAN or score it.

    Side effect: ``args`` gains ``generator_optimizer`` and
    ``discriminator_optimizer`` attributes, each an ``Adam`` built from the
    matching learning rate and the shared ``args.beta1`` / ``args.beta2``.

    Args:
        dataset: Passed to ``compute_scores`` and, in the training phase, to
            ``Trainer``.
        args: Parsed option namespace. All of its attributes are forwarded as
            keyword arguments to the GAN class and to ``Trainer``.
        generator_params: Namespace whose attributes become the keyword
            arguments of ``make_generator``.
        discriminator_params: Namespace whose attributes become the keyword
            arguments of ``make_discriminator``.

    Raises:
        KeyError: in the training phase, if ``args.gan_type`` is not one of
            ``None``, ``'AC_GAN'`` or ``'PROJECTIVE'``.
    """
    # Serialize the options *before* the optimizer objects are attached to
    # ``args`` below; those objects are presumably not JSON-serializable.
    additional_info = json.dumps(vars(args))

    def adam_with_shared_betas(learning_rate):
        # Both optimizers share beta_1/beta_2 and differ only in learning rate.
        return Adam(learning_rate, beta_1=args.beta1, beta_2=args.beta2)

    args.generator_optimizer = adam_with_shared_betas(args.generator_lr)
    args.discriminator_optimizer = adam_with_shared_betas(args.discriminator_lr)

    log_path = os.path.join(args.output_dir, 'log.txt')
    score_hook = partial(
        compute_scores,
        image_shape=args.image_shape,
        log_file=log_path,
        dataset=dataset,
        images_inception=args.samples_inception,
        images_fid=args.samples_fid,
        additional_info=additional_info,
        cache_file=args.fid_cache_file,
    )

    gen_lr_schedule, disc_lr_schedule = get_lr_decay_schedule(args)

    checkpoints = (args.generator_checkpoint, args.discriminator_checkpoint)

    generator = make_generator(**vars(generator_params))
    discriminator = make_discriminator(**vars(discriminator_params))
    models = (generator, discriminator)

    for model in models:
        model.summary()

    # Optionally warm-start each network from its own weights file.
    for model, weights_path in zip(models, checkpoints):
        if weights_path is not None:
            model.load_weights(weights_path)

    score_generator = partial(score_hook, generator=generator)

    if args.phase != 'train':
        # Scoring only. The 0 is forwarded as compute_scores' first positional
        # argument; presumably an epoch/iteration index (confirm).
        score_generator(0)
        return

    gan_class = {
        None: GAN,
        'AC_GAN': AC_GAN,
        'PROJECTIVE': ProjectiveGAN,
    }[args.gan_type]
    gan = gan_class(
        generator=generator,
        discriminator=discriminator,
        lr_decay_schedule_discriminator=disc_lr_schedule,
        lr_decay_schedule_generator=gen_lr_schedule,
        **vars(args))
    trainer = Trainer(dataset, gan, at_store_checkpoint_hook=score_generator,
                      **vars(args))
    trainer.train()
Esempio n. 2
0
def main():
    """Train a WGAN-GP on images read from ``--input_folder``.

    Command-line arguments are parsed before the networks are built, so
    ``--help`` or an invalid flag exits without constructing any model.
    """
    args = parser_with_default_args().parse_args()

    generator = make_generator()
    discriminator = make_discriminator()

    # NOTE(review): the meaning of the two shape tuples below is defined by
    # FolderDataset (not visible here); confirm before changing them.
    dataset = FolderDataset(args.input_folder, args.batch_size, (128, ),
                            (128, 64))
    gan = WGAN_GP(generator, discriminator, **vars(args))
    trainer = Trainer(dataset, gan, **vars(args))

    trainer.train()
Esempio n. 3
0
def main():
    """Train a conditional GAN on sketch/photo data.

    Extends ``parser_with_default_args()`` with sketch-specific flags, builds
    the ``SketchDataset`` and both networks, then runs ``Trainer``.
    """

    def parse_image_size(value):
        # Convert a comma-separated command-line value such as "64,64" into a
        # tuple of ints. argparse only applies ``type`` to string values, so
        # the tuple default below is used unchanged.
        return tuple(int(part) for part in value.split(','))

    parser = parser_with_default_args()
    parser.add_argument("--images_folder",
                        default="data/photo/tx_000100000000",
                        help='Folder with photos')
    parser.add_argument("--sketch_folder",
                        default="data/sketch/tx_000000000000",
                        help='Folder with sketches')
    parser.add_argument(
        "--invalid_files",
        default=[
            'data/info/invalid-ambiguous.txt', 'data/info/invalid-context.txt',
            'data/info/invalid-error.txt', 'data/info/invalid-pose.txt'
        ],
        help='List of files with invalid sketches, comma separated',
        type=lambda x: x.split(','))
    parser.add_argument("--test_set",
                        default='data/info/testset.txt',
                        help='File with test set')
    # Without ``type`` a command-line value would reach make_generator and
    # SketchDataset as a raw string instead of a tuple.
    parser.add_argument("--image_size",
                        default=(64, 64),
                        type=parse_image_size,
                        help='Size of the images, comma separated (e.g. 64,64)')
    # Without ``type=int`` a command-line value would be passed on as a str.
    parser.add_argument(
        "--number_of_classes",
        default=2,
        type=int,
        help='Number of classes to train on, useful for debugging')
    parser.add_argument("--cache_dir",
                        default='tmp',
                        help='Store distance transforms to this folder.')

    args = parser.parse_args()

    dataset = SketchDataset(images_folder=args.images_folder,
                            sketch_folder=args.sketch_folder,
                            batch_size=args.batch_size,
                            invalid_images_files=args.invalid_files,
                            test_set=args.test_set,
                            number_of_classes=args.number_of_classes,
                            image_size=args.image_size)

    generator = make_generator(image_size=args.image_size,
                               number_of_classes=args.number_of_classes)
    discriminator = make_discriminator(
        image_size=args.image_size, number_of_classes=args.number_of_classes)
    generator.summary()
    discriminator.summary()

    gan = CGAN(generator=generator, discriminator=discriminator, **vars(args))
    trainer = Trainer(dataset, gan, **vars(args))

    trainer.train()
Esempio n. 4
0
def main():
    """Train a GDAN on MNIST, scoring the generator at each checkpoint.

    Arguments are parsed before the networks are built, so ``--help`` or an
    invalid flag exits without constructing models or printing summaries.
    """
    args = parser_with_default_args().parse_args()

    generator = make_generator()
    discriminator = make_discriminator()
    generator.summary()
    discriminator.summary()

    dataset = MNISTDataset(args.batch_size)
    gan = GDAN(generator=generator, discriminator=discriminator, **vars(args))

    # Checkpoint hook: compute_scores on single-channel 28x28 images, with
    # the inception score disabled.
    hook = partial(compute_scores, generator=generator, dataset=dataset,
                   image_shape=(28, 28, 1), compute_inception=False)
    trainer = Trainer(dataset, gan, at_store_checkpoint_hook=hook, **vars(args))
    trainer.train()
Esempio n. 5
0
def main():
    """Train a conditional GAN on ``PoseHMDataset`` using ``cmd.args()``.

    Both networks receive the same five architecture options and may each be
    warm-started from an optional checkpoint before training.
    """
    args = cmd.args()

    # Shared positional configuration for both model factories.
    architecture = (args.image_size, args.use_input_pose, args.warp_skip,
                    args.disc_type, args.warp_agg)

    generator = make_generator(*architecture)
    gen_weights = args.generator_checkpoint
    if gen_weights is not None:
        generator.load_weights(gen_weights)

    discriminator = make_discriminator(*architecture)
    disc_weights = args.discriminator_checkpoint
    if disc_weights is not None:
        discriminator.load_weights(disc_weights)

    options = vars(args)
    dataset = PoseHMDataset(test_phase=False, **options)
    gan = CGAN(generator, discriminator, **options)
    Trainer(dataset, gan, **options).train()
Esempio n. 6
0
def main():
    """Train a GAN on the style dataset described by the command line.

    The parser is configured and run before the networks are built, so
    ``--help`` or an invalid flag exits without constructing any model.
    """
    parser = parser_with_default_args()
    parser.add_argument("--input_dir", default='dataset/devian_art')
    parser.add_argument("--cache_file_name", default='output/devian_art.npy')
    parser.add_argument("--content_image",
                        default='sup-mat/cornell_cropped.jpg')

    args = parser.parse_args()

    generator = make_generator()
    discriminator = make_discriminator()

    dataset = StylesDataset(args.batch_size,
                            args.input_dir,
                            cache_file_name=args.cache_file_name,
                            content_image=args.content_image)
    gan = GAN(generator, discriminator, **vars(args))
    trainer = Trainer(dataset, gan, **vars(args))

    trainer.train()
Esempio n. 7
0
def main():
    """Train a conditional GAN with landmark/mask options from ``cmd.args()``.

    Note the asymmetry: generator weights are loaded with ``by_name=True``
    (in Keras this matches layers by name), while the discriminator uses a
    plain ``load_weights``. Presumably this is so a generator checkpoint from
    a slightly different architecture can be reused; confirm.
    """
    args = cmd.args()

    generator = make_generator(
        args.image_size,
        args.use_input_pose,
        args.warp_agg,
        args.num_landmarks,
        args.num_mask,
    )
    generator.summary()
    gen_weights = args.generator_checkpoint
    if gen_weights is not None:
        generator.load_weights(gen_weights, by_name=True)

    discriminator = make_discriminator(
        args.image_size,
        args.use_input_pose,
        args.num_landmarks,
        args.num_mask,
    )
    disc_weights = args.discriminator_checkpoint
    if disc_weights is not None:
        discriminator.load_weights(disc_weights)

    options = vars(args)
    dataset = PoseHMDataset(test_phase=False, **options)
    gan = CGAN(generator, discriminator, **options)

    Trainer(dataset, gan, **options).train()