예제 #1
0
    def load_model_and_dataset(checkpt_filename):
        """Restore a trained VAE and its data loader from a checkpoint.

        Args:
            checkpt_filename: Path to a checkpoint containing the training
                ``args`` namespace and the model ``state_dict``.

        Returns:
            Tuple of ``(vae, loader)`` — the model in eval mode and the
            dataset loader built from the saved training arguments.

        Raises:
            ValueError: If ``args.dist`` names an unknown distribution.
        """
        # map_location keeps loading working on CPU-only machines even when
        # the checkpoint was saved from CUDA tensors.
        checkpt = torch.load(checkpt_filename,
                             map_location=lambda storage, loc: storage)
        args = checkpt['args']
        state_dict = checkpt['state_dict']

        # backwards compatibility: older checkpoints predate the `conv` flag
        if not hasattr(args, 'conv'):
            args.conv = False

        from vae_quant import VAE, setup_data_loaders

        # prior / approximate-posterior distribution families
        if args.dist == 'normal':
            prior_dist = dist.Normal()
            q_dist = dist.Normal()
        elif args.dist == 'laplace':
            prior_dist = dist.Laplace()
            q_dist = dist.Laplace()
        elif args.dist == 'flow':
            prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim,
                                                        nsteps=32)
            q_dist = dist.Normal()
        else:
            # fail fast instead of hitting a NameError on prior_dist below
            raise ValueError('unrecognized distribution: %r' % (args.dist,))
        vae = VAE(z_dim=args.latent_dim,
                  use_cuda=True,
                  prior_dist=prior_dist,
                  q_dist=q_dist,
                  conv=args.conv)
        # strict=False tolerates checkpoints with extra/missing keys
        vae.load_state_dict(state_dict, strict=False)
        vae.eval()

        # dataset loader
        loader = setup_data_loaders(args, use_cuda=True)
        return vae, loader
예제 #2
0
    def load_model_and_dataset(checkpt_filename):
        """Restore a trained VAE, its data loader, and the saved args.

        Args:
            checkpt_filename: Path to a checkpoint containing the training
                ``args`` namespace and the model ``state_dict``.

        Returns:
            Tuple of ``(vae, loader, args)``.

        Raises:
            ValueError: If ``args.dist`` names an unknown distribution.
        """
        print('Loading model and dataset.')
        # map_location keeps loading working on CPU-only machines.
        checkpt = torch.load(checkpt_filename,
                             map_location=lambda storage, loc: storage)
        args = checkpt['args']
        state_dict = checkpt['state_dict']

        # backwards compatibility: older checkpoints predate the `conv`
        # flag, but args.conv is read unconditionally below
        if not hasattr(args, 'conv'):
            args.conv = False

        # model: pick prior / approximate-posterior families; checkpoints
        # that predate the `dist` flag default to Normal
        if not hasattr(args, 'dist') or args.dist == 'normal':
            prior_dist = dist.Normal()
            q_dist = dist.Normal()
        elif args.dist == 'laplace':
            prior_dist = dist.Laplace()
            q_dist = dist.Laplace()
        elif args.dist == 'flow':
            prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim,
                                                        nsteps=4)
            q_dist = dist.Normal()
        else:
            # fail fast instead of hitting a NameError on prior_dist below
            raise ValueError('unrecognized distribution: %r' % (args.dist,))
        vae = VAE(z_dim=args.latent_dim,
                  use_cuda=True,
                  prior_dist=prior_dist,
                  q_dist=q_dist,
                  conv=args.conv)
        # strict=False tolerates checkpoints with extra/missing keys
        vae.load_state_dict(state_dict, strict=False)

        # dataset loader
        loader = setup_data_loaders(args)
        return vae, loader, args
예제 #3
0
def load_model_and_dataset(checkpt_filename):
    """Restore a trained SensVAE plus train and test data loaders.

    Args:
        checkpt_filename: Path to a checkpoint containing the training
            ``args`` namespace and the model ``state_dict``.

    Returns:
        Tuple of ``(vae, loader, test_loader, args)`` — the model in eval
        mode, the training-set loader, a CelebA test-set loader, and the
        saved training arguments.

    Raises:
        ValueError: If ``args.dist`` names an unknown distribution.
    """
    # map_location keeps loading working on CPU-only machines even when
    # the checkpoint was saved from CUDA tensors.
    checkpt = torch.load(checkpt_filename,
                         map_location=lambda storage, loc: storage)
    args = checkpt['args']
    state_dict = checkpt['state_dict']

    # backwards compatibility: older checkpoints predate the `conv` flag
    if not hasattr(args, 'conv'):
        args.conv = False

    # CelebA pixels are modeled as Normal; other datasets as Bernoulli
    x_dist = dist.Normal() if args.dataset == 'celeba' else dist.Bernoulli()
    a_dist = dist.Bernoulli()

    # prior / approximate-posterior distribution families
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim,
                                                    nsteps=32)
        q_dist = dist.Normal()
    else:
        # fail fast instead of hitting a NameError on prior_dist below
        raise ValueError('unrecognized distribution: %r' % (args.dist,))
    vae = SensVAE(z_dim=args.latent_dim,
                  use_cuda=True,
                  prior_dist=prior_dist,
                  q_dist=q_dist,
                  include_mutinfo=not args.exclude_mutinfo,
                  tcvae=args.tcvae,
                  conv=args.conv,
                  mss=args.mss,
                  n_chan=3 if args.dataset == 'celeba' else 1,
                  sens_idx=SENS_IDX,
                  x_dist=x_dist,
                  a_dist=a_dist)

    # strict=False tolerates checkpoints with extra/missing keys
    vae.load_state_dict(state_dict, strict=False)
    vae.beta = args.beta
    vae.beta_sens = args.beta_sens
    vae.eval()

    # dataset loader
    loader = setup_data_loaders(args, use_cuda=True)

    # test loader
    # NOTE(review): this always builds a CelebA test split even when
    # args.dataset != 'celeba' — confirm that is intentional.
    test_set = dset.CelebA(mode='test')
    kwargs = {'num_workers': 4, 'pin_memory': True}
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             **kwargs)

    return vae, loader, test_loader, args
예제 #4
0
def load_model_and_dataset(checkpt_filename):
    """Restore a trained VAE (or InfoGAN wrapped as a VAE) and its dataset.

    Args:
        checkpt_filename: Path to a checkpoint containing the training
            ``args`` namespace and the model ``state_dict``.

    Returns:
        Tuple of ``(vae, dataset, args)`` — note the underlying dataset is
        returned, not the loader itself.

    Raises:
        ValueError: If ``args.dist`` names an unknown distribution.
    """
    print('Loading model and dataset.')
    # map_location keeps loading working on CPU-only machines.
    checkpt = torch.load(checkpt_filename,
                         map_location=lambda storage, loc: storage)
    args = checkpt['args']
    state_dict = checkpt['state_dict']

    # backwards compatibility: older checkpoints predate the `conv` flag
    if not hasattr(args, 'conv'):
        args.conv = False

    # prior / approximate-posterior families; checkpoints that predate the
    # `dist` flag default to Normal
    if not hasattr(args, 'dist') or args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim,
                                                    nsteps=32)
        q_dist = dist.Normal()
    else:
        # fail fast instead of hitting a NameError on prior_dist below
        raise ValueError('unrecognized distribution: %r' % (args.dist,))

    # model
    if hasattr(args, 'ncon'):
        # InfoGAN checkpoint: load the GAN, then graft its encoder/decoder
        # onto a VAE shell so downstream code sees a uniform interface.
        model = infogan.Model(args.latent_dim,
                              n_con=args.ncon,
                              n_cat=args.ncat,
                              cat_dim=args.cat_dim,
                              use_cuda=True,
                              conv=args.conv)
        model.load_state_dict(state_dict, strict=False)
        vae = vae_quant.VAE(z_dim=args.ncon,
                            use_cuda=True,
                            prior_dist=prior_dist,
                            q_dist=q_dist,
                            conv=args.conv)
        vae.encoder = model.encoder
        vae.decoder = model.decoder
    else:
        vae = vae_quant.VAE(z_dim=args.latent_dim,
                            use_cuda=True,
                            prior_dist=prior_dist,
                            q_dist=q_dist,
                            conv=args.conv)
        # strict=False tolerates checkpoints with extra/missing keys
        vae.load_state_dict(state_dict, strict=False)

    # dataset loader (only the underlying dataset is returned)
    loader = vae_quant.setup_data_loaders(args)
    return vae, loader.dataset, args