Exemple #1
0
        latent_dim=LATENT_DIM,
        random_sampling=False
    )
    goptim_cfg = OptimConfig('Adam', lr=0.0002)
    doptim_cfg = OptimConfig('Adam', lr=0.0002)

    dataset_cfg = DatasetConfig('CelebA', dataset_len=DATASET_LEN, image_size=IMAGE_SIZE, preload_len=PRELOAD_LEN)
    loader_cfg = LoaderConfig('naive', batch_size=128, shuffle=True)

    gen_cfg = ModelConfig(
        'EDCG', input_size=LATENT_DIM,
        hidden_size=128, output_size=3,
        data_num=PRELOAD_LEN, out_64=IMAGE_SIZE == 64
    )
    dis_cfg = ModelConfig('DCD', input_size=3, hidden_size=64, output_size=1, out_64=IMAGE_SIZE == 64)

    gan_cfg = GanConfig(
        name='CelebAGAN', gen_cfg=gen_cfg, dis_cfg=dis_cfg,
        gen_step=1, dis_step=1, gan_epoch=EPOCHS, loader_cfg=loader_cfg,
        dataset_cfg=dataset_cfg, gloss_cfg=gloss_cfg, dloss_cfg=dloss_cfg,
        goptim_cfg=goptim_cfg, doptim_cfg=doptim_cfg, label_smooth=False,
        sampler_cfg=sampler_cfg, dist_loss=False, device=device
    )
    gan_schema = GanSchema()
    gan_desc = gan_schema.dump(gan_cfg)
    logger.save_cfg(gan_desc)

    gan = gan_cfg.get()
    gan.train(use_tqdm=True)
    gan.save('checkpoints', 'CelebA64')
Exemple #2
0
    dataset_cfg = DatasetConfig('MNIST', train=True, stack=False, along_width=False, size=32)
    loader_cfg = LoaderConfig('naive', batch_size=128, shuffle=True)

    gen_cfg = ModelConfig('EDCG', input_size=LATENT_DIM, hidden_size=128, output_size=1, data_num=DATASET_LEN)
    dis_cfg = ModelConfig('DCD', input_size=1, hidden_size=128, output_size=1)

    train_data_cfg = DatasetConfig('MNIST', stack=False, train=True, along_width=False)
    test_data_cfg = DatasetConfig('MNIST', stack=False, train=False, along_width=False)

    train_loader = LoaderConfig('naive', batch_size=128, shuffle=True)
    test_loader = LoaderConfig('naive', batch_size=128, shuffle=True)

    util_cfg = UtilityModelConfig('NaiveClassifier', False, '/home/bourgan/gan_dev/checkpoints/mnist_naive.pth.tar',
                                  15, train_loader, test_loader, train_data_cfg, test_data_cfg, 32, 'cuda')
    gan_cfg = GanConfig(
        name='StackMNISTLPG', gen_cfg=gen_cfg, dis_cfg=dis_cfg,
        gen_step=1, dis_step=1, gan_epoch=EPOCHS, loader_cfg=loader_cfg,
        dataset_cfg=dataset_cfg, gloss_cfg=gloss_cfg, dloss_cfg=dloss_cfg,
        goptim_cfg=goptim_cfg, doptim_cfg=doptim_cfg, label_smooth=False,
        sampler_cfg=sampler_cfg, dist_loss=False, device=device, util_cfg=util_cfg

    )
    gan_schema = GanSchema()
    gan_desc = gan_schema.dump(gan_cfg)
    logger.save_cfg(gan_desc)

    gan = gan_cfg.get()
    gan.train(use_tqdm=True)
    gan.save('checkpoints', 'mnist')
Exemple #3
0
    loader_cfg = LoaderConfig('naive', batch_size=128, shuffle=True)

    gen_cfg = ModelConfig('MLG',
                          input_size=latent_dim,
                          hidden_size=32,
                          output_size=2)
    dis_cfg = ModelConfig('MLD', input_size=2, hidden_size=32, output_size=1)

    gan_cfg = GanConfig(name='LearningPrior2',
                        gen_cfg=gen_cfg,
                        dis_cfg=dis_cfg,
                        gen_step=1,
                        dis_step=1,
                        gan_epoch=20000,
                        loader_cfg=loader_cfg,
                        dataset_cfg=dataset_cfg,
                        gloss_cfg=gloss_cfg,
                        dloss_cfg=dloss_cfg,
                        goptim_cfg=goptim_cfg,
                        doptim_cfg=doptim_cfg,
                        label_smooth=False,
                        sampler_cfg=sampler_cfg,
                        dist_loss=False,
                        device=device)
    gan_schema = GanSchema()
    gan_desc = gan_schema.dump(gan_cfg)
    logger.save_cfg(gan_desc)

    gan = gan_cfg.get()
    # th.nn.init.normal_(gan.gen.sample_matrix.weight)
    gan.train(use_tqdm=True)
Exemple #4
0
                          input_nc=3,
                          output_nc=3,
                          num_downs=6,
                          ngf=64,
                          use_resizeconv=True,
                          ex_label=True,
                          label_len=label_len)
    dis_cfg = ModelConfig('PatchDis', input_nc=6)

    gan_cfg = GanConfig(name='Pix2Pix',
                        gen_cfg=gen_cfg,
                        dis_cfg=dis_cfg,
                        gen_step=1,
                        dis_step=1,
                        gan_epoch=GAN_EPOCHS,
                        loader_cfg=loader_cfg,
                        dataset_cfg=dataset_cfg,
                        gloss_cfg=gloss_cfg,
                        dloss_cfg=dloss_cfg,
                        goptim_cfg=goptim_cfg,
                        doptim_cfg=doptim_cfg,
                        device=device)
    gan = gan_cfg.get()
    gan.load('./p2p/' + EXP_NAME)

    SAVE_DIR = './eval_result/' + EXP_NAME
    os.makedirs(SAVE_DIR, exist_ok=True)

    dirs = gan.eval_on_dataset(SAVE_DIR, label_len=label_len)
    dirs = [['/'.join(x.split('/')[-2:]) for x in d] for d in dirs]
    gen_cfg = ModelConfig('EDCG',
                          input_size=LATENT_DIM,
                          hidden_size=128,
                          output_size=1,
                          data_num=DATASET_LEN)
    dis_cfg = ModelConfig('DCD', input_size=1, hidden_size=128, output_size=1)

    gan_cfg = GanConfig(name='MNIST-WGAN-GP',
                        gen_cfg=gen_cfg,
                        dis_cfg=dis_cfg,
                        gen_step=1,
                        dis_step=1,
                        gan_epoch=EPOCHS,
                        loader_cfg=loader_cfg,
                        dataset_cfg=dataset_cfg,
                        gloss_cfg=gloss_cfg,
                        dloss_cfg=dloss_cfg,
                        goptim_cfg=goptim_cfg,
                        doptim_cfg=doptim_cfg,
                        label_smooth=False,
                        sampler_cfg=sampler_cfg,
                        dist_loss=False,
                        device=device,
                        LAMBDA=0.1)
    gan_schema = GanSchema()
    gan_desc = gan_schema.dump(gan_cfg)
    logger.save_cfg(gan_desc)

    gan = gan_cfg.get()
    gan.train(use_tqdm=True)
    gan.save('checkpoints', 'mnist')
Exemple #6
0
    sampler_cfg = SamplerConfig(
        name='Gaussian',
        out_shape=latent_dim,
    )

    goptim_cfg = OptimConfig('Adam', lr=1e-3)
    doptim_cfg = OptimConfig('Adam', lr=1e-3)

    dataset_cfg = DatasetConfig('Spiral', mode=mode, sig=1, num_per_mode=dataset_len // mode)
    # dataset_cfg = DatasetConfig('GMM', mode=mode, sig=1, num_per_mode=dataset_len // mode)
    loader_cfg = LoaderConfig('naive', batch_size=128, shuffle=True)

    gen_cfg = ModelConfig('MLG', input_size=latent_dim, hidden_size=32, output_size=2)
    dis_cfg = ModelConfig('MLD', input_size=2, hidden_size=32, output_size=1)

    gan_cfg = GanConfig(
        name='Vanilla', gen_cfg=gen_cfg, dis_cfg=dis_cfg,
        gen_step=1, dis_step=1, gan_epoch=15000, loader_cfg=loader_cfg,
        dataset_cfg=dataset_cfg, gloss_cfg=gloss_cfg, dloss_cfg=dloss_cfg,
        goptim_cfg=goptim_cfg, doptim_cfg=doptim_cfg, label_smooth=False,
        sampler_cfg=sampler_cfg, device=device
    )
    gan_schema = GanSchema()
    gan_desc = gan_schema.dump(gan_cfg)
    logger.save_cfg(gan_desc)

    gan = gan_cfg.get()
    gan.train(use_tqdm=True)