def Hellinger(theta):
    # Example of usage of the code provided for answering Q2.5 as well as recommended hyper parameters.
    model = q2_model.Critic(2)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    sampler1 = iter(q2_sampler.distribution1(0, 512))
    sampler2 = iter(q2_sampler.distribution1(theta, 512))
    lambda_reg_lp = 50  # Recommended hyper parameters for the Lipschitz regularizer.
    steps = 500
    for step in range(steps):
        data1 = torch.from_numpy(next(sampler1)).float()
        data2 = torch.from_numpy(next(sampler2)).float()
        loss = -vf_squared_hellinger(data1, data2, model)
        print('Step {} : loss {}'.format(step, loss))
        optim.zero_grad()
        loss.backward()
        optim.step()
    data1 = torch.from_numpy(next(sampler1)).float()
    data2 = torch.from_numpy(next(sampler2)).float()
    return vf_squared_hellinger(data1, data2, model)
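# A minimal usage sketch (not part of the graded code): sweeping Hellinger(theta) over a
# range of offsets, as one might do for Q2.5. The theta grid, the matplotlib import and
# the output file name are illustrative assumptions, not taken from the original file.
def _hellinger_sweep_example():
    import numpy as np
    import matplotlib.pyplot as plt

    thetas = np.arange(0.0, 2.1, 0.1)  # hypothetical grid of offsets
    estimates = [Hellinger(theta).item() for theta in thetas]

    plt.plot(thetas, estimates)
    plt.xlabel('theta')
    plt.ylabel('Estimated squared Hellinger')
    plt.savefig('hellinger_vs_theta.png')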
def lp_reg(x, y, critic, device='cpu'):
    """
    COMPLETE ME. DONT MODIFY THE PARAMETERS OF THE FUNCTION. Otherwise, tests might fail.

    *** The notation used for the parameters follows the one from Petzka et al:
    https://arxiv.org/pdf/1709.08894.pdf
    In other words, x are samples from the distribution mu and y are samples from the
    distribution nu. The critic is the equivalent of f in the paper. Also consider that
    the norm used is the L2 norm. This is important to consider, because we make the
    assumption that your implementation follows this notation when testing your function. ***

    :param x: (FloatTensor) - shape: (batchsize x 2) - Samples from a distribution P.
    :param y: (FloatTensor) - shape: (batchsize x 2) - Samples from a distribution Q.
    :param critic: (Module) - torch module that you want to regularize.
    :return: (FloatTensor) - shape: (1,) - Lipschitz penalty
    """
    # Random interpolation coefficients t ~ U(0, 1), taken from the second column of the sampler.
    sampler = iter(q2_sampler.distribution1(0, x.size(0)))
    data = next(sampler)
    t = torch.tensor(data[:, 1], dtype=torch.float32, device=device).unsqueeze(1)

    # Points sampled on the segments between pairs (x, y).
    x_hat = t * x + (1 - t) * y
    x_hat.requires_grad_(True)

    fx_hat = critic(x_hat)
    grads = torch.autograd.grad(fx_hat, x_hat,
                                grad_outputs=torch.ones_like(fx_hat),
                                create_graph=True)[0]
    norm_grad = torch.norm(grads, p=2, dim=1)

    # One-sided penalty: only gradient norms larger than 1 are penalized.
    lp = torch.clamp(norm_grad - 1, min=0.).pow(2)
    return lp.mean()
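# Hypothetical smoke test for lp_reg above (illustrative only, not in the original file):
# for random 2-D batches the penalty should be a single non-negative value. The batch
# size 512 and the use of q2_model.Critic(2) mirror the example usage further down.
def _lp_reg_smoke_test():
    critic = q2_model.Critic(2)
    x = torch.randn(512, 2)
    y = torch.randn(512, 2)
    penalty = lp_reg(x, y, critic)
    assert penalty.numel() == 1 and penalty.item() >= 0.0
    print('Lipschitz penalty on random data:', penalty.item())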
def lp_reg(x, y, critic):
    """
    COMPLETE ME. DONT MODIFY THE PARAMETERS OF THE FUNCTION. Otherwise, tests might fail.

    *** The notation used for the parameters follows the one from Petzka et al:
    https://arxiv.org/pdf/1709.08894.pdf
    In other words, x are samples from the distribution mu and y are samples from the
    distribution nu. The critic is the equivalent of f in the paper. Also consider that
    the norm used is the L2 norm. This is important to consider, because we make the
    assumption that your implementation follows this notation when testing your function. ***

    :param x: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution P.
    :param y: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution Q.
    :param critic: (Module) - torch module that you want to regularize.
    :return: (FloatTensor) - shape: (1,) - Lipschitz penalty
    """
    # Random interpolation coefficients a ~ U(0, 1) from the second column of the sampler.
    a = torch.from_numpy(
        next(iter(q2_sampler.distribution1(0, batch_size=x.size(0))))[:, 1]).float()

    # Interpolate between pairs (x, y) and track gradients w.r.t. the interpolated points.
    z = x * a[:, None] + y * (1 - a[:, None])
    z.requires_grad_(True)

    fz = critic(z)
    grad_z = torch.autograd.grad(outputs=fz, inputs=z,
                                 grad_outputs=torch.ones(fz.size()),
                                 create_graph=True, retain_graph=True)[0]
    grad_z = grad_z.view(grad_z.size(0), -1)

    # One-sided penalty on gradient norms above 1, averaged over the batch.
    out = torch.mean(
        torch.relu(torch.norm(grad_z, p=2, dim=-1, keepdim=True) - 1) ** 2, dim=0)
    return out
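# A minimal sketch of how lp_reg plugs into one critic update for the Wasserstein
# distance with the Lipschitz penalty of Petzka et al. It assumes the signature
# vf_wasserstein_distance(x, y, critic) (matching vf_squared_hellinger below); the
# values theta=1.0 and lr=1e-3 are illustrative, while lambda_reg_lp = 50 follows the
# recommended hyper parameters used elsewhere in this file.
def _wgan_lp_critic_step_example(theta=1.0):
    critic = q2_model.Critic(2)
    optim_c = torch.optim.SGD(critic.parameters(), lr=1e-3)
    lambda_reg_lp = 50

    x = torch.from_numpy(next(iter(q2_sampler.distribution1(0, 512)))).float()
    y = torch.from_numpy(next(iter(q2_sampler.distribution1(theta, 512)))).float()

    # Maximize E_P[f(x)] - E_Q[f(y)] by minimizing its negative, plus the penalty term.
    loss = -vf_wasserstein_distance(x, y, critic) + lambda_reg_lp * lp_reg(x, y, critic)
    optim_c.zero_grad()
    loss.backward()
    optim_c.step()
    return loss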
    :param critic: (Module) - torch module used to compute the Wasserstein distance.
    :return: (FloatTensor) - shape: (1,) - Estimate of the Wasserstein distance
    """
    return torch.mean(critic(x)) - torch.mean(critic(y))


def vf_squared_hellinger(x, y, critic):
    """
    Complete me. DONT MODIFY THE PARAMETERS OF THE FUNCTION. Otherwise, tests might fail.

    *** The notation used for the parameters follows the one from Nowozin et al:
    https://arxiv.org/pdf/1606.00709.pdf
    In other words, x are samples from the distribution P and y are samples from the
    distribution Q. Please note that the critic is unbounded. ***

    :param x: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution P.
    :param y: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution Q.
    :param critic: (Module) - torch module used to compute the Squared Hellinger.
    :return: (FloatTensor) - shape: (1,) - Estimate of the Squared Hellinger
    """
    # f-GAN variational bound for the squared Hellinger: g_f(v) = 1 - exp(-v) and
    # f*(t) = t / (1 - t), so f*(g_f(v)) = (1 - exp(-v)) / exp(-v).
    return torch.mean(1 - torch.exp(-critic(x))) - torch.mean(
        (1 - torch.exp(-critic(y))) / torch.exp(-critic(y)))


if __name__ == '__main__':
    # Example of usage of the code provided for answering Q2.5 as well as recommended hyper parameters.
    model = q2_model.Critic(2)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    sampler1 = iter(q2_sampler.distribution1(0, 512))
    theta = 0
    sampler2 = iter(q2_sampler.distribution1(theta, 512))
    lambda_reg_lp = 50  # Recommended hyper parameters for the Lipschitz regularizer.
wd_dict = {}
for theta in thetas:
    model1 = q2_model.Critic(2).to(device)
    optim1 = torch.optim.SGD(model1.parameters(), lr=1e-3)
    model2 = q2_model.Critic(2).to(device)
    optim2 = torch.optim.SGD(model2.parameters(), lr=1e-3)
    lambda_reg_lp = 50  # Recommended hyper parameters for the Lipschitz regularizer.
    iterations = 2500
    for i in range(iterations):
        # print("iteration and theta is: ", i, " ", theta)
        # The same data is used for both models.
        sampler1 = iter(q2_sampler.distribution1(0, 512))
        sampler2 = iter(q2_sampler.distribution1(theta, 512))
        data1 = torch.from_numpy(next(sampler1)).to(device)
        data2 = torch.from_numpy(next(sampler2)).to(device)

        # Squared Hellinger distance (model1).
        out1 = model1(data1.type(dtype))
        out2 = model1(data2.type(dtype))

        # Compute the loss: maximize E_P[g(x)] - E_Q[f*(g(y))] by minimizing its negative.
        px = 1.0 - torch.exp(-out1)
        py = 1.0 - torch.exp(-out2)
        fpy = -py / (1.0 - py)
        loss_hellinger = -(torch.mean(px) + torch.mean(fpy))

        optim1.zero_grad()