Exemple #1
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('image_dir', type=str)
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=1000)
    parser.add_argument('--noise_dim', '-nd', type=int, default=128)
    parser.add_argument('--height', '-ht', type=int, default=128)
    parser.add_argument('--width', '-wd', type=int, default=128)
    parser.add_argument('--save_steps', '-ss', type=int, default=10)
    parser.add_argument('--visualize_steps', '-vs', type=int, default=10)
    parser.add_argument('--lambda', '-l', type=float, default=10., dest='lmbd')
    parser.add_argument('--initial_steps', '-is', type=int, default=20)
    parser.add_argument('--initial_critics', '-sc', type=int, default=20)
    parser.add_argument('--normal_critics', '-nc', type=int, default=5)
    parser.add_argument('--model_dir', '-md', type=str, default="./params")
    parser.add_argument('--result_dir', '-rd', type=str, default="./result")
    parser.add_argument('--noise_mode', '-nm', type=str, default="uniform")
    parser.add_argument('--upsampling', '-up', type=str, default="subpixel")
    parser.add_argument('--dis_norm', '-dn', type=str, default=None)

    args = parser.parse_args()

    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    # output config to csv
    config_path = os.path.join(args.result_dir, "config.csv")
    dict_ = vars(args)
    df = pd.DataFrame(list(dict_.items()), columns=['attr', 'status'])
    df.to_csv(config_path, index=None)

    input_shape = (args.height, args.width, 3)

    image_sampler = ImageSampler(args.image_dir, target_size=input_shape[:2])
    noise_sampler = NoiseSampler(args.noise_mode)

    generator = Generator(args.noise_dim,
                          is_training=True,
                          upsampling=args.upsampling)
    discriminator = Discriminator(input_shape,
                                  is_training=True,
                                  normalization=args.dis_norm)

    wgan = WGAN(generator, discriminator, lambda_=args.lmbd, is_training=True)

    wgan.fit(image_sampler.flow(args.batch_size),
             noise_sampler,
             nb_epoch=args.nb_epoch,
             result_dir=args.result_dir,
             model_dir=args.model_dir,
             save_steps=args.save_steps,
             visualize_steps=args.visualize_steps,
             initial_steps=args.initial_steps,
             initial_critics=args.initial_critics,
             normal_critics=args.normal_critics)
Exemple #2
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--nb_epoch', '-e', type=int, default=1000)
    parser.add_argument('--latent_dim', '-ld', type=int, default=2)
    parser.add_argument('--height', '-ht', type=int, default=32)
    parser.add_argument('--width', '-wd', type=int, default=32)
    parser.add_argument('--channel', '-ch', type=int, default=1)
    parser.add_argument('--save_steps', '-ss', type=int, default=10)
    parser.add_argument('--visualize_steps', '-vs', type=int, default=10)
    parser.add_argument('--model_dir', '-md', type=str, default="./params")
    parser.add_argument('--result_dir', '-rd', type=str, default="./result")
    parser.add_argument('--noise_mode', '-nm', type=str, default="normal")

    args = parser.parse_args()

    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    dump_config(os.path.join(args.result_dir, 'config.csv'), args)
    input_shape = (args.height, args.width, args.channel)

    image_sampler = ImageSampler(
        target_size=(args.width, args.height),
        color_mode='rgb' if args.channel == 3 else 'gray',
        is_training=True)
    noise_sampler = NoiseSampler(args.noise_mode)

    autoencoder = AutoEncoder(
        input_shape,
        args.latent_dim,
        is_training=True,
        color_mode='rgb' if args.channel == 3 else 'gray')
    discriminator = Discriminator(is_training=True)

    aae = AAE(autoencoder, discriminator, is_training=True)

    train_x, _ = load_mnist(mode='training')
    aae.fit_generator(image_sampler.flow(train_x, batch_size=args.batch_size),
                      noise_sampler,
                      nb_epoch=args.nb_epoch,
                      save_steps=args.save_steps,
                      visualize_steps=args.visualize_steps,
                      result_dir=args.result_dir,
                      model_dir=args.model_dir)
Exemple #3
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-bs', type=int, default=64)
    parser.add_argument('--latent_dim', '-ld', type=int, default=2)
    parser.add_argument('--height', '-ht', type=int, default=32)
    parser.add_argument('--width', '-wd', type=int, default=32)
    parser.add_argument('--channel', '-ch', type=int, default=1)
    parser.add_argument('--model_path', '-mp', type=str, default="./params")
    parser.add_argument('--result_dir', '-rd', type=str, default="./result")

    args = parser.parse_args()
    os.makedirs(args.result_dir, exist_ok=True)

    input_shape = (args.height, args.width, args.channel)

    autoencoder = AutoEncoder(
        input_shape,
        args.latent_dim,
        is_training=False,
        color_mode='rgb' if args.channel == 3 else 'gray')
    discriminator = Discriminator(is_training=False)

    aae = AAE(autoencoder, discriminator, is_training=False)
    aae.restore(args.model_path)

    test_x, test_y = load_mnist(mode='test')
    image_sampler = ImageSampler(
        target_size=(args.width, args.height),
        color_mode='rgb' if args.channel == 3 else 'gray',
        is_training=False)
    encoded = aae.predict_latent_vectors_generator(
        image_sampler.flow(test_x, shuffle=False))

    df = pd.DataFrame({
        'z_1': encoded[:, 0],
        'z_2': encoded[:, 1],
        'label': test_y
    })
    df.plot(kind='scatter', x='z_1', y='z_2', c='label', cmap='Set1', s=10)
    plt.savefig(os.path.join(args.result_dir, 'scatter.png'))