def test_nvil(self):
    k, eta, toy_experiment, true_grad = self.get_params_get_true_grad()

    n_samples = 10000
    nvil_grads = torch.zeros(n_samples)

    baseline_nn = bs_lib.BaselineNN(slen=5)

    # sample NVIL gradient estimator, check it is unbiased
    for i in range(n_samples):
        toy_experiment.set_parameter(eta)
        pm_loss = toy_experiment.get_pm_loss(
            topk=0,
            grad_estimator=bs_lib.nvil,
            grad_estimator_kwargs={'baseline_nn': baseline_nn})

        pm_loss.backward()

        nvil_grads[i] = toy_experiment.eta.grad

    assert_close(true_grad, torch.mean(nvil_grads),
                 tol=3 * torch.std(nvil_grads) / np.sqrt(n_samples))
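# --- Illustrative sketch (not part of the test suite) ---
# The test above checks unbiasedness by comparing the Monte-Carlo mean of the
# NVIL gradients to the analytic gradient within three standard errors. The
# snippet below reproduces that check for a baseline-subtracted score-function
# (REINFORCE / NVIL-style) estimator on a hypothetical Bernoulli toy problem;
# the loss f, the constant baseline, and eta are made up for illustration and
# are not part of bs_lib.
import torch

torch.manual_seed(0)

eta = torch.tensor(0.3)                  # logit of the Bernoulli parameter
p = torch.sigmoid(eta)
f = lambda z: (z - 0.45) ** 2            # downstream loss of the discrete sample
baseline = 0.2                           # any z-independent baseline keeps the estimator unbiased

# analytic gradient: d/d_eta E[f(z)] = p (1 - p) (f(1) - f(0))
true_grad = p * (1 - p) * (f(torch.tensor(1.0)) - f(torch.tensor(0.0)))

n_samples = 200000
z = torch.bernoulli(p.expand(n_samples))   # z ~ Bernoulli(p)
score = z - p                              # d/d_eta log p(z | eta) for Bernoulli(sigmoid(eta))
grads = (f(z) - baseline) * score          # baseline-subtracted score-function estimator

mc_grad = grads.mean()
stderr = grads.std() / n_samples ** 0.5
print(float(true_grad), float(mc_grad))
assert (mc_grad - true_grad).abs() < 3 * stderr   # statistical check, as in test_nvil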
    'params': c_phi.parameters()}],
    lr=1e-2)

elif args.grad_estimator == 'gumbel':
    print('gumbel anneal rate: ', args.gumbel_anneal_rate)
    grad_estimator = bs_lib.gumbel
    grad_estimator_kwargs = {
        'annealing_fun': lambda t:
            np.maximum(0.5,
                       np.exp(- args.gumbel_anneal_rate * float(t) *
                              len(train_loader.sampler) / args.batch_size)),
        'straight_through': True}

elif args.grad_estimator == 'nvil':
    grad_estimator = bs_lib.nvil
    baseline_nn = bs_lib.BaselineNN(slen=train_set[0]['image'].shape[-1])
    grad_estimator_kwargs = {'baseline_nn': baseline_nn.to(device)}

    optimizer = optim.Adam([
        {'params': vae.pixel_attention.parameters(),
         'lr': args.learning_rate,
         'weight_decay': args.weight_decay},
        {'params': vae.mnist_vae.parameters(),
         'lr': args.learning_rate,
         'weight_decay': args.weight_decay},
        {'params': baseline_nn.parameters(),
         'lr': args.learning_rate,
         'weight_decay': args.weight_decay}])
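# --- Illustrative sketch (not part of the training script) ---
# The 'annealing_fun' lambda handed to the Gumbel-Softmax estimator above
# floors the temperature at 0.5 and decays it exponentially in the
# (approximate) number of optimization steps, t * len(sampler) / batch_size.
# The anneal rate, dataset size, and batch size below are made-up values.
import numpy as np

gumbel_anneal_rate = 1e-5
num_train = 60000            # stand-in for len(train_loader.sampler)
batch_size = 64

def annealing_fun(t):
    return np.maximum(0.5,
                      np.exp(-gumbel_anneal_rate * float(t) *
                             num_train / batch_size))

for epoch in [0, 1, 5, 10, 50]:
    print(epoch, annealing_fun(epoch))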
    bs_optimizer = optim.Adam([{'params': [temperature_param]},
                               {'params': c_phi.parameters()}], lr=1e-2)

elif args.grad_estimator == 'gumbel':
    grad_estimator = bs_lib.gumbel
    print('annealing rate: ', args.gumbel_anneal_rate)
    grad_estimator_kwargs = {
        'annealing_fun': lambda t:
            np.maximum(0.5,
                       np.exp(- args.gumbel_anneal_rate * float(t) *
                              len(train_loader_labeled.sampler) / args.batch_size)),
        'straight_through': False}

elif args.grad_estimator == 'nvil':
    grad_estimator = bs_lib.nvil
    baseline_nn = bs_lib.BaselineNN(slen=slen)
    grad_estimator_kwargs = {'baseline_nn': baseline_nn.to(device)}

    optimizer = optim.Adam([
        {'params': classifier.parameters(), 'lr': args.learning_rate},
        {'params': vae.parameters(), 'lr': args.learning_rate},
        {'params': baseline_nn.parameters(), 'lr': args.learning_rate}],
        weight_decay=args.weight_decay)

else:
    print('invalid gradient estimator')
    raise NotImplementedError

# train!
outfile = args.outdir + args.outfilename
ss_lib.train_semisuper_vae(vae, classifier,