# コード例 #1 (Code example #1)
# 0
# ファイル: evaluate.py  プロジェクト: dntai/PAE  (File: evaluate.py, Project: dntai/PAE)
            flags.LEARNING_RATE = 1e-3
        else:
            flags.LEARNING_RATE = 2e-4  # follow radford
            flags.EXPONENTIAL_DECAY = True
        noise_dist = UniformNoise(flags.NOISE_DIM)
        input_noise_dist = UniformNoise(1)
        from veegan import VEEGAN
        model = VEEGAN(data_dist, noise_dist, input_noise_dist, flags, args)
    elif METHOD == 'wgan':
        noise_dist = NormalNoise(flags.NOISE_DIM)
        eps_dist = UniformNoise(1)
        from wgan import WGANGP
        model = WGANGP(data_dist, noise_dist, eps_dist, flags, args)
    elif METHOD == 'vae':
        from vae import VAE
        noise_dist = NormalNoise(flags.NOISE_DIM)
        model = VAE(data_dist, noise_dist, flags, args)
    elif METHOD == 'aae':
        from aae import AAE
        model = AAE(data_dist, noise_dist, flags, args)
    elif METHOD == 'avbac':
        from avbac import *
        noise_dist = NormalNoise(flags.NOISE_DIM)
        single_noise_dist = NormalNoise(PERTURB)
        model = AVB_AC(data_dist, noise_dist, single_noise_dist, flags, args)

    # Build the graph of whichever model the dispatch above selected, then
    # hand it to an Evaluator, pointing it at <working_dir>/model and the
    # checkpoint id requested on the command line.
    model.create_model()

    evaluator = Evaluator(model)
    checkpoint_dir = os.path.join(args.working_dir, 'model')
    evaluator.run(checkpoint_dir, args.ckpt_id)