def test_nvil(self):
        k, eta, toy_experiment, true_grad = self.get_params_get_true_grad()

        n_samples = 10000

        nvil_grads = torch.zeros(n_samples)

        baseline_nn = bs_lib.BaselineNN(slen=5)
        # sample the NVIL gradient estimator repeatedly and check that its
        # Monte Carlo mean matches the true gradient (i.e. it is unbiased)
        for i in range(n_samples):
            toy_experiment.set_parameter(eta)
            pm_loss = toy_experiment.get_pm_loss(
                topk=0,
                grad_estimator=bs_lib.nvil,
                grad_estimator_kwargs={'baseline_nn': baseline_nn})
            pm_loss.backward()

            nvil_grads[i] = toy_experiment.eta.grad

        # the sample mean should agree with the true gradient to within
        # three standard errors of the Monte Carlo estimate
        assert_close(true_grad,
                     torch.mean(nvil_grads),
                     tol=3 * torch.std(nvil_grads) / np.sqrt(n_samples))
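The test above relies on bs_lib.nvil being a score-function (REINFORCE) estimator whose learned baseline acts only as a control variate, so the estimate stays unbiased. A minimal sketch of that idea, assuming the sampled discrete variable's log-probability, a per-sample loss, and a baseline prediction as inputs (the function name and signature are illustrative, not bs_lib's API):

def nvil_surrogate(log_prob_z, loss, baseline):
    # REINFORCE term: the advantage is detached, so the baseline only reduces
    # variance and does not bias the gradient flowing through log_prob_z
    advantage = (loss - baseline).detach()
    reinforce_term = advantage * log_prob_z
    # regression term that trains the baseline network to track the loss
    baseline_loss = (loss.detach() - baseline) ** 2
    return reinforce_term + baseline_loss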
Example 2
    bs_optimizer = optim.Adam([{
        'params': c_phi.parameters()
    }],
                              lr=1e-2)

elif args.grad_estimator == 'gumbel':
    print('gumbel anneal rate: ', args.gumbel_anneal_rate)
    grad_estimator = bs_lib.gumbel
    # anneal the Gumbel-softmax temperature exponentially in the number of
    # gradient steps, floored at 0.5 (a numeric sketch of this schedule
    # follows after this example)
    grad_estimator_kwargs = {
        'annealing_fun': lambda t: np.maximum(
            0.5,
            np.exp(-args.gumbel_anneal_rate * float(t) *
                   len(train_loader.sampler) / args.batch_size)),
        'straight_through': True}

elif args.grad_estimator == 'nvil':
    grad_estimator = bs_lib.nvil
    # NVIL uses a learned, image-dependent baseline network as a control variate
    baseline_nn = bs_lib.BaselineNN(slen=train_set[0]['image'].shape[-1])
    grad_estimator_kwargs = {'baseline_nn': baseline_nn.to(device)}

    # give the baseline network its own parameter group so it is trained
    # jointly with the VAE
    optimizer = optim.Adam([{
        'params': vae.pixel_attention.parameters(),
        'lr': args.learning_rate,
        'weight_decay': args.weight_decay
    }, {
        'params': vae.mnist_vae.parameters(),
        'lr': args.learning_rate,
        'weight_decay': args.weight_decay
    }, {
        'params': baseline_nn.parameters(),
        'lr': args.learning_rate,
        'weight_decay': args.weight_decay
    }])
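To make the annealing_fun in the gumbel branch above concrete, here is a small standalone sketch of the temperature it produces, assuming t counts epochs; the rate, dataset size, and batch size below are made-up stand-ins for args.gumbel_anneal_rate, len(train_loader.sampler), and args.batch_size, not values from the script:

import numpy as np

anneal_rate = 1e-5       # stand-in for args.gumbel_anneal_rate
n_train = 60000          # stand-in for len(train_loader.sampler)
batch_size = 64          # stand-in for args.batch_size

def temperature(t):
    # exponential decay in the (approximate) number of gradient steps taken,
    # floored at 0.5 so the relaxed samples never become too sharply peaked
    return np.maximum(0.5, np.exp(-anneal_rate * t * n_train / batch_size))

for t in (0, 1, 5, 10, 50):
    print(t, temperature(t))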
Example 3
    bs_optimizer = optim.Adam([{'params': [temperature_param]},
                               {'params': c_phi.parameters()}], lr=1e-2)

elif args.grad_estimator == 'gumbel':
    grad_estimator = bs_lib.gumbel
    print('annealing rate: ', args.gumbel_anneal_rate)
    # same exponential temperature schedule as in Example 2, but here the
    # relaxed (soft) samples are used directly rather than straight-through
    grad_estimator_kwargs = {
        'annealing_fun': lambda t: np.maximum(
            0.5,
            np.exp(-args.gumbel_anneal_rate * float(t) *
                   len(train_loader_labeled.sampler) / args.batch_size)),
        'straight_through': False}

elif args.grad_estimator == 'nvil':
    grad_estimator = bs_lib.nvil
    # learned image-dependent baseline, as in Example 2 (a hypothetical sketch
    # of such a baseline network follows this if/elif block)
    baseline_nn = bs_lib.BaselineNN(slen=slen)
    grad_estimator_kwargs = {'baseline_nn': baseline_nn.to(device)}

    optimizer = optim.Adam(
        [{'params': classifier.parameters(), 'lr': args.learning_rate},
         {'params': vae.parameters(), 'lr': args.learning_rate},
         {'params': baseline_nn.parameters(), 'lr': args.learning_rate}],
        weight_decay=args.weight_decay)

else:
    print('invalid gradient estimator: ', args.grad_estimator)
    raise NotImplementedError
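# bs_lib.BaselineNN(slen=...) is built from the image side length in both
# NVIL branches, so it is presumably a small network that maps each image to
# a scalar baseline. A hypothetical stand-in with that interface (not
# bs_lib's actual architecture), shown here for illustration only:
import torch.nn as nn

class ToyBaselineNN(nn.Module):
    def __init__(self, slen):
        super().__init__()
        # flatten the image and regress a single scalar baseline per example
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(slen * slen, 64),
            nn.ReLU(),
            nn.Linear(64, 1))

    def forward(self, image):
        # image: [batch, 1, slen, slen] -> baseline: [batch]
        return self.net(image).squeeze(-1)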

# train!
outfile = args.outdir + args.outfilename
ss_lib.train_semisuper_vae(vae, classifier,