def gradient_penalty(images, output, weight=10):
    """Return the WGAN-GP style gradient penalty of ``output`` w.r.t. ``images``.

    The penalty is ``weight * mean_i (||d output / d images_i||_2 - 1) ** 2``,
    where each sample's gradient is flattened over all non-batch dimensions.

    Args:
        images: tensor ``output`` was computed from; must require grad. Its
            first dimension is treated as the batch dimension.
        output: tensor computed from ``images`` (e.g. discriminator scores).
        weight: multiplier applied to the penalty.

    Returns:
        A scalar tensor that stays attached to the graph (``create_graph=True``),
        so the penalty itself can be back-propagated.
    """
    batch_size = images.shape[0]
    gradients = torch_grad(
        outputs=output,
        inputs=images,
        # Seed gradient on output's own device/dtype instead of the hard-coded
        # default CUDA device, so this also works on CPU and non-default GPUs.
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # reshape (not view): the returned gradient is not guaranteed contiguous.
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def calc_pl_lengths(styles, images):
    """Return path lengths of ``images`` w.r.t. ``styles``.

    ``images`` is projected onto Gaussian noise scaled by ``1 / sqrt(H * W)``
    (``H, W = images.shape[2], images.shape[3]``). The resulting scalar is
    differentiated w.r.t. ``styles``, and the squared gradient is reduced as
    ``sqrt(mean_dim1(sum_dim2(g ** 2)))``.

    Args:
        styles: tensor ``images`` was generated from (requires grad). It is
            reduced over ``dim=2`` then ``dim=1``, so it must be at least 3-D
            (presumably ``(batch, num_layers, style_dim)`` -- TODO confirm).
        images: generated tensor with spatial sizes at ``shape[2]``/``shape[3]``.

    Returns:
        For a 3-D ``styles``, one non-negative value per batch element. The
        result is kept in the graph (``create_graph=True``).
    """
    num_pixels = images.shape[2] * images.shape[3]
    # Sample the noise on images' device/dtype rather than on the hard-coded
    # default CUDA device, so this also runs on CPU / other GPUs.
    pl_noise = torch.randn(images.shape, device=images.device,
                           dtype=images.dtype) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()
    pl_grads = torch_grad(
        outputs=outputs,
        inputs=styles,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
def gradient_penalty(images, output, weight=10):
    """Penalise per-sample gradient norms of ``output`` that deviate from 1.

    Computes ``weight * mean((||d output / d images_b||_2 - 1) ** 2)``. Each
    sample's gradient is flattened over every non-batch dimension before the
    norm is taken.

    Args:
        images: input tensor (requires grad); dimension 0 is the batch.
        output: tensor derived from ``images``.
        weight: scale of the penalty.

    Returns:
        Scalar tensor that remains differentiable (``create_graph=True``).
    """
    seed = torch.ones(output.size(), device=images.device)
    (grads,) = torch_grad(outputs=output,
                          inputs=images,
                          grad_outputs=seed,
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)
    per_sample_norm = grads.reshape(images.shape[0], -1).norm(2, dim=1)
    squared_gap = (per_sample_norm - 1) ** 2
    return weight * squared_gap.mean()
def forward(self, styles, images, device):
    """Return one path-length value for the whole batch.

    ``images`` is projected onto Gaussian noise scaled by ``1 / sqrt(H * W)``
    and the projection is differentiated w.r.t. ``styles``. Every squared
    gradient entry is then summed before the square root, so the result is a
    0-dim tensor rather than a per-sample vector.

    Args:
        self: not used by this body.
        styles: tensor ``images`` was produced from (requires grad).
        images: generated tensor; ``shape[2] * shape[3]`` is the pixel count.
        device: device the noise and the seed gradient are moved to.
    """
    height, width = images.shape[2], images.shape[3]
    scale = math.sqrt(height * width)
    # Noise is drawn with the default generator and then moved to ``device``.
    noise = torch.randn(images.shape).to(device=device) / scale
    projected = torch.sum(images * noise)
    seed = torch.ones(projected.shape).to(device=device)
    grads = torch_grad(outputs=projected,
                       inputs=styles,
                       grad_outputs=seed,
                       create_graph=True,
                       retain_graph=True,
                       only_inputs=True)[0]
    # NOTE: .mean() of the already fully-reduced sum is a no-op; it is kept so
    # the operation sequence matches the original exactly.
    return grads.pow(2).sum().mean().sqrt()
def gradient_penalty(images, outputs, weight=10):
    """Gradient penalty over one or several outputs computed from ``images``.

    The gradients of every tensor in ``outputs`` w.r.t. ``images`` are summed,
    with each output seeded with ones. The sum is flattened per sample and
    penalised as ``weight * mean((||g||_2 - 1) ** 2)``.

    Args:
        images: input tensor (requires grad); dimension 0 is the batch.
        outputs: sequence of tensors computed from ``images`` (e.g. several
            discriminator heads). A single tensor is also accepted and treated
            as a one-element sequence.
        weight: multiplier for the penalty.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    if torch.is_tensor(outputs):
        # Iterating a tensor yields its rows, which would produce grad_outputs
        # that do not match ``outputs``; wrap it instead.
        outputs = [outputs]
    outputs = list(outputs)
    batch_size = images.shape[0]
    gradients = torch_grad(
        outputs=outputs,
        inputs=images,
        # One seed per output, on that output's own device/dtype (the original
        # hard-coded ``.cuda()``, which fails on CPU and non-default GPUs).
        grad_outputs=[torch.ones_like(t) for t in outputs],
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def CG_normaleq(params: List[Tensor],
                hparams: List[Tensor],
                K: int,
                fp_map: Callable[[List[Tensor], List[Tensor]], List[Tensor]],
                outer_loss: Callable[[List[Tensor], List[Tensor]], Tensor],
                tol=1e-10,
                set_grad=True) -> List[Tensor]:
    """Hypergradient via conjugate gradient applied to the normal equation.

    Similar to CG, but CG is run on the normal equation of the fixed-point
    linear system, which has a higher time complexity. Write ``J^T x`` for the
    vector-Jacobian product of ``fp_map`` w.r.t. ``params`` (computed with
    ``torch_grad``) and ``A x = x - J^T x``. The system solved is
    ``A^T A v = A^T dE/dw``. This assumes ``jvp`` (defined elsewhere)
    computes the forward product ``J x``, so that ``A^T x = x - J x`` --
    TODO confirm.

    Args:
        params: inner-problem solution (approximate fixed point of ``fp_map``);
            detached copies that require grad are used.
        hparams: outer variables whose hypergradients are returned.
        K: maximum number of CG iterations (``max_iter``).
        fp_map: fixed-point map ``(params, hparams) -> params``.
        outer_loss: outer objective ``(params, hparams) -> scalar``.
        tol: passed to ``CG_torch.cg`` as ``epsilon``.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.

    Returns:
        One hypergradient per element of ``hparams``. Hyperparameters that
        ``fp_map`` does not depend on receive only their outer-loss gradient.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    # Built once and reused (retain_graph=True) by every CG operator call.
    w_mapped = fp_map(params, hparams)

    def dfp_map_dw(xs):
        # A xs = xs - J^T xs
        Jfp_mapTv = torch_grad(w_mapped,
                               params,
                               grad_outputs=xs,
                               retain_graph=True)
        v_minus_Jfp_mapTv = [v - j for v, j in zip(xs, Jfp_mapTv)]

        # normal equation part: apply A^T to A xs
        Jfp_mapv_minus_Jfp_mapJfp_mapTv = jvp(
            lambda _params: fp_map(_params, hparams), params,
            v_minus_Jfp_mapTv)
        return [
            v - vv for v, vv in zip(v_minus_Jfp_mapTv,
                                    Jfp_mapv_minus_Jfp_mapJfp_mapTv)
        ]

    # Right-hand side A^T dE/dw = dE/dw - J dE/dw
    v_minus_Jfp_mapv = [
        g - jfp_mapv for g, jfp_mapv in zip(
            grad_outer_w,
            jvp(lambda _params: fp_map(_params, hparams), params,
                grad_outer_w))
    ]
    vs = CG_torch.cg(dfp_map_dw, v_minus_Jfp_mapv, max_iter=K,
                     epsilon=tol)  # K steps of conjugate gradient

    # (d fp_map / d hparams)^T v; hparams not used by fp_map yield None.
    grads = torch_grad(w_mapped, hparams, grad_outputs=vs, allow_unused=True)
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
def neumann(params, hparams, K, fp_map, outer_loss, tol=1e-10, set_grad=True):
    """Hypergradient via a truncated Neumann series.

    From https://arxiv.org/pdf/1803.06396.pdf; should return the same
    gradient as the fixed-point method run for K+1 iterations.

    Approximates ``v = (I - J^T)^{-1} dE/dw`` by ``sum_{k=0..K} (J^T)^k dE/dw``.
    Here ``J^T x`` is the vector-Jacobian product of ``fp_map`` w.r.t.
    ``params``. The series stops early once the newly added term has norm
    below ``tol``. The result is ``dE/dhparams + (d fp_map / d hparams)^T v``.

    Args:
        params: inner solution (approximate fixed point); detached copies are used.
        hparams: outer variables.
        K: maximum number of series terms after the first.
        fp_map: fixed-point map ``(params, hparams) -> params``.
        outer_loss: outer objective ``(params, hparams) -> scalar``.
        tol: early-stopping threshold on the norm of each added term.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.

    Returns:
        One hypergradient per element of ``hparams``. Hyperparameters that
        ``fp_map`` does not depend on receive only their outer-loss gradient.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    w_mapped = fp_map(params, hparams)
    vs, gs = grad_outer_w, grad_outer_w
    gs_vec = cat_list_to_tensor(gs)
    for _ in range(K):
        gs_prev_vec = gs_vec
        # Next series term: J^T applied to the previous term.
        vs = torch_grad(w_mapped, params, grad_outputs=vs, retain_graph=True)
        gs = [g + v for g, v in zip(gs, vs)]
        gs_vec = cat_list_to_tensor(gs)
        if float(torch.norm(gs_vec - gs_prev_vec)) < tol:
            break

    # allow_unused: a hyperparameter that only enters the outer loss has no
    # path through fp_map. Without the flag torch_grad raises; with it the
    # implicit part is None and is treated as zero (as CG_normaleq does).
    grads = torch_grad(w_mapped, hparams, grad_outputs=gs, allow_unused=True)
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
def _gradient_penalty(self, real_data: torch.Tensor, generated_data: torch.Tensor) -> float: """ Computes gradient penalty for given pair of real and generated data. Applies gradient penalty weight as well. :param real_data: Example from training data. :param generated_data: Output from Generator. """ batch_size = real_data.size()[0] # Calculate interpolation alpha = torch.rand(batch_size, 1, 1, 1) alpha = alpha.expand_as(real_data) if self.use_cuda: alpha = alpha.cuda() interpolated = (alpha * real_data.data + (1 - alpha) * generated_data.data) interpolated = Variable(interpolated, requires_grad=True) if self.use_cuda: interpolated = interpolated.cuda() # Calculate probability of interpolated examples prob_interpolated = self.D(interpolated) # Calculate gradients of probabilities with respect to examples gradients = torch_grad( outputs=prob_interpolated, inputs=interpolated, grad_outputs=torch.ones(prob_interpolated.size()).cuda() if self.use_cuda else torch.ones(prob_interpolated.size()), create_graph=True, retain_graph=True)[0] # Gradients have shape: # (batch_size, num_channels, img_width, img_height), # so flatten to easily take norm per example in batch gradients = gradients.view(batch_size, -1) gradnorm = gradients.norm(2, dim=1) self.losses['gradient_norm'].append(gradnorm.mean().data.item()) # print(f"Saved gradnorm: # f"{self.gp_weight * ((gradnorm - 1) ** 2).mean()}") # Return gradient penalty return self.gp_weight * ((gradnorm - 1)**2).mean()
def CG_normaleq(params,
                hparams,
                K,
                fp_map,
                outer_loss,
                tol=1e-10,
                set_grad=True):
    """Hypergradient via conjugate gradient applied to the normal equation.

    Write ``J^T x`` for the vector-Jacobian product of ``fp_map`` w.r.t.
    ``params`` (computed with ``torch_grad``) and ``A x = x - J^T x``. The
    system solved with ``CG_torch.cg`` is ``A^T A v = A^T dE/dw``. This
    assumes ``jvp`` (defined elsewhere) computes ``J x``, so that
    ``A^T x = x - J x`` -- TODO confirm. The extra ``jvp`` per operator
    application makes this costlier than plain CG.

    Args:
        params: inner solution; detached copies that require grad are used.
        hparams: outer variables whose hypergradients are returned.
        K: maximum number of CG iterations.
        fp_map: fixed-point map ``(params, hparams) -> params``.
        outer_loss: outer objective ``(params, hparams) -> scalar``.
        tol: passed to ``CG_torch.cg`` as ``epsilon``.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.

    Returns:
        One hypergradient per element of ``hparams``. Hyperparameters unused
        by ``fp_map`` receive only their outer-loss gradient.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    # Built once and reused (retain_graph=True) by every CG operator call.
    w_mapped = fp_map(params, hparams)

    def dfp_map_dw(xs):
        # A xs = xs - J^T xs
        Jfp_mapTv = torch_grad(w_mapped,
                               params,
                               grad_outputs=xs,
                               retain_graph=True)
        v_minus_Jfp_mapTv = [v - j for v, j in zip(xs, Jfp_mapTv)]

        # normal equation part: apply A^T to A xs
        Jfp_mapv_minus_Jfp_mapJfp_mapTv = jvp(
            lambda _params: fp_map(_params, hparams), params,
            v_minus_Jfp_mapTv)
        return [
            v - vv for v, vv in zip(v_minus_Jfp_mapTv,
                                    Jfp_mapv_minus_Jfp_mapJfp_mapTv)
        ]

    # Right-hand side A^T dE/dw = dE/dw - J dE/dw
    v_minus_Jfp_mapv = [
        g - jfp_mapv for g, jfp_mapv in zip(
            grad_outer_w,
            jvp(lambda _params: fp_map(_params, hparams), params,
                grad_outer_w))
    ]
    vs = CG_torch.cg(dfp_map_dw, v_minus_Jfp_mapv, max_iter=K,
                     epsilon=tol)  # K steps of conjugate gradient

    # (d fp_map / d hparams)^T v; hparams not used by fp_map yield None.
    grads = torch_grad(w_mapped, hparams, grad_outputs=vs, allow_unused=True)
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
def dfp_map_dw(xs):
    """Apply the normal-equation operator of the fixed-point system to ``xs``.

    First computes ``r = xs - J^T xs``, where ``J^T xs`` is a vector-Jacobian
    product through ``w_mapped`` w.r.t. ``params``. Then returns
    ``r - jvp(fp_map(., hparams), params, r)``. If ``jvp`` (defined
    elsewhere) is the forward product ``J r``, this is ``(I - J)(I - J^T) xs``
    -- TODO confirm.

    ``w_mapped``, ``params``, ``hparams`` and ``fp_map`` are free variables
    taken from the enclosing scope.
    """
    def fixed_point(p):
        return fp_map(p, hparams)

    vjp_terms = torch_grad(w_mapped, params, grad_outputs=xs,
                           retain_graph=True)
    residual = [x - vt for x, vt in zip(xs, vjp_terms)]
    jvp_terms = jvp(fixed_point, params, residual)
    return [r - jt for r, jt in zip(residual, jvp_terms)]
def gradient_penalty(D, real_images, fake_images, weight=10):
    """One-sided gradient penalty based on the max-abs (L-infinity) gradient norm.

    ``D`` is evaluated on ``interpolate(real_images, fake_images)``, where
    ``interpolate`` is defined elsewhere. The function returns
    ``weight * mean(relu(max_j |dD/dx_j| - 1))``. Only samples whose largest
    absolute gradient component exceeds 1 are penalised, and linearly.

    Args:
        D: callable critic/discriminator.
        real_images: batch passed to ``interpolate``.
        fake_images: batch passed to ``interpolate``.
        weight: multiplier for the penalty.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    x_interp = interpolate(real_images, fake_images)
    x_interp.requires_grad_(True)
    o_interp = D(x_interp)
    grad = torch_grad(outputs=o_interp,
                      inputs=x_interp,
                      # Seed on the output's own device/dtype instead of the
                      # hard-coded default CUDA device.
                      grad_outputs=torch.ones_like(o_interp),
                      create_graph=True,
                      retain_graph=True,
                      only_inputs=True)[0]
    # reshape (not view): the gradient is not guaranteed contiguous.
    grad = grad.reshape(grad.shape[0], -1)
    grad_norm, _ = torch.max(torch.abs(grad), 1)
    return weight * F.relu(grad_norm - 1).mean()
def gradient_penalty(images, output, weight=10, return_structured_grads=False):
    """Compute ``weight * mean((||d output / d images_b||_2 - 1) ** 2)``.

    Args:
        images: input tensor (requires grad); dimension 0 is the batch.
        output: tensor computed from ``images``.
        weight: scale of the penalty.
        return_structured_grads: when True, the unflattened gradient (same
            shape as ``images``) is returned together with the penalty.

    Returns:
        The scalar penalty, or ``(penalty, gradients)`` when
        ``return_structured_grads`` is set. Both stay in the graph.
    """
    seed = torch.ones(output.size(), device=images.device)
    (grads,) = torch_grad(outputs=output,
                          inputs=images,
                          grad_outputs=seed,
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)
    norms = grads.reshape(images.shape[0], -1).norm(2, dim=1)
    penalty = weight * ((norms - 1) ** 2).mean()
    return (penalty, grads) if return_structured_grads else penalty
def reverse(params_history: List[List[Tensor]],
            hparams: List[Tensor],
            update_map_history: List[Callable[[List[Tensor], List[Tensor]],
                                              List[Tensor]]],
            outer_loss: Callable[[List[Tensor], List[Tensor]], Tensor],
            set_grad=True) -> List[Tensor]:
    """
    Computes the hypergradient by recomputing and backpropagating through each
    inner update, using the inner iterates and the update maps previously
    employed by the inner solver. Similarly to checkpointing, this saves memory
    compared with reverse_unroll at the cost of extra computation time.
    Truncated reverse can be performed by passing only part of the trajectory
    information, i.e. only the last k inner iterates and updates.

    Args:
        params_history: the inner iterates (from first to last)
        hparams: the outer variables (or hyperparameters), each element needs requires_grad=True
        update_map_history: updates used to solve the inner problem (from first to last).
            The last entry is applied to params_history[-2], so at least
            len(params_history) - 1 maps are needed.
        outer_loss: computes the outer objective taking parameters and hyperparameters as inputs
        set_grad: if True, update_tensor_grads(hparams, grads) is called
            (presumably setting t.grad for every t in hparams)

    Returns:
        the list of hypergradients for each element in hparams
    """
    # Detach every stored iterate so each update can be re-run on fresh leaves.
    params_history = [[w.detach().requires_grad_(True) for w in params]
                      for params in params_history]
    o_loss = outer_loss(params_history[-1], hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params_history[-1], hparams)

    # Adjoint of the last iterate: d(outer loss) / d(params_history[-1]).
    alphas = grad_outer_w
    grads = [torch.zeros_like(w) for w in hparams]
    K = len(params_history) - 1
    # k = -2, -3, ..., -(K + 1): walk backwards from the second-to-last iterate
    # down to params_history[0]. update_map_history[k + 1] is the map applied
    # to params_history[k] (the last map pairs with params_history[-2]).
    for k in range(-2, -(K + 2), -1):
        # Recompute one inner update from the stored iterate.
        w_mapped = update_map_history[k + 1](params_history[k], hparams)
        # This step's contribution to the hypergradient. grad_unused_zero is
        # defined elsewhere (presumably zeros for hparams unused by the update).
        bs = grad_unused_zero(w_mapped,
                              hparams,
                              grad_outputs=alphas,
                              retain_graph=True)
        grads = [g + b for g, b in zip(grads, bs)]
        # Propagate the adjoint one step back through the update.
        alphas = torch_grad(w_mapped, params_history[k], grad_outputs=alphas)

    # Add the direct dependence of the outer loss on hparams.
    grads = [g + v for g, v in zip(grads, grad_outer_hparams)]
    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
def compute_gradient_penalty(self, real_data, fake_data):
    """One-sided gradient penalty ``mean(max(||g||_2 - 1, 0) ** 2)``.

    Interpolates between ``real_data`` and ``fake_data`` with one random
    coefficient per example and differentiates ``self.D`` w.r.t. the
    interpolates. Only gradient norms above 1 are penalised. The value is
    unweighted; any lambda must be applied by the caller.

    References:
        https://github.com/EmilienDupont/wgan-gp/blob/master/training.py
        https://github.com/caogang/wgan-gp/blob/master/gan_toy.py

    Args:
        real_data: real batch (dimension 0 is the batch).
        fake_data: generated batch of the same shape.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    # Use the actual batch size: ``self.batch_size`` breaks on a smaller final
    # batch. One coefficient per example, broadcast over the remaining dims,
    # sampled with the default generator and moved to the data's device
    # (``self.device`` no longer needs to agree with where the data lives).
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.to(real_data.device)

    # Detached leaf that requires grad (what ``Variable(..., requires_grad=True)``
    # produced).
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    interpolated = interpolated.detach().requires_grad_(True)

    # Calculate probability of interpolated examples
    prob_interpolated = self.D(interpolated)

    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)[0]

    # One norm per example (identical to dim=1 for 2-D data). The epsilon keeps
    # sqrt differentiable at zero.
    gradients = gradients.reshape(batch_size, -1)
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # One-sided penalty: (max(|grad f| - 1, 0))^2 instead of (|grad f| - 1)^2.
    a = torch.max(gradients_norm - 1, torch.zeros_like(gradients_norm))
    return (a ** 2).mean()
def L_adv_loss(self, x_img, target_img, alpha, i):
    """WGAN-GP gradient penalty (weight 10) for ``D.discriminator``.

    Draws 16 interpolation coefficients of shape ``(16, 1, 1, 1)``. With a
    batch of 1 these broadcast to 16 interpolates between the same pair; with
    a batch of 16 each example gets its own coefficient.

    Args:
        self: not used by this body.
        x_img: batch to interpolate from (4-D; NOTE(review): earlier comments
            suggested shapes like (1, 128, 128, 3) -- confirm).
        target_img: batch to interpolate towards, broadcast-compatible with ``x_img``.
        alpha: forwarded to ``D.discriminator`` as its ``alpha`` argument.
        i: zero-based block index; ``block_num=i + 1`` is passed on.

    Returns:
        ``10 * mean((||g||_2 - 1) ** 2)`` over the interpolates.
    """
    # Coefficients live on x_img's device. The original pinned them to cuda:1,
    # which fails whenever the inputs are on any other device.
    rands = torch.rand(16, 1, 1, 1).to(x_img.device)
    interpolated = rands * x_img + (1. - rands) * target_img
    # Detached leaf that requires grad, as ``Variable(..., requires_grad=True)``
    # produced.
    interpolated = interpolated.detach().requires_grad_(True)
    logit = D.discriminator(interpolated, block_num=i + 1, alpha=alpha)
    gradients = torch_grad(outputs=logit,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(logit),
                           create_graph=True,
                           retain_graph=True)[0]
    # Improved WGAN (WGAN-GP) penalty: one norm per interpolate.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    return 10 * ((gradients_norm - 1) ** 2).mean()
def get_fisher_info(n_samples=30):
    """Estimate per-parameter Fisher information for ``generator``.

    Averages the element-wise squared gradients of ``self.D`` evaluated on
    ``self.sample_generator(1)``, over ``n_samples`` independent draws.
    ``self`` and ``generator`` are free variables from the enclosing scope.

    Args:
        n_samples: number of single-sample draws to average over.

    Returns:
        List of tensors, one per generator parameter, each shaped like that
        parameter.
    """
    gen_params = list(generator.parameters())
    # Accumulators on each parameter's own device/dtype. This replaces the
    # ``use_cuda`` branch, which could disagree with where the parameters live.
    sums = [torch.zeros_like(p) for p in gen_params]
    # Loop one sample at a time so we don't run out of CUDA memory.
    for _ in range(n_samples):
        sampled_data = self.sample_generator(1)
        log_probs = self.D(sampled_data)
        loss_grads = torch_grad(outputs=log_probs, inputs=gen_params)
        # Pair accumulators with gradients directly. The original indexed up
        # to len(self.init_params), which skips parameters (or raises
        # IndexError) when that count differs from the generator's.
        sums = [s + g ** 2 for s, g in zip(sums, loss_grads)]
    return [s / n_samples for s in sums]
def CG(params,
       hparams,
       K,
       fp_map,
       outer_loss,
       tol=1e-10,
       set_grad=True,
       stochastic=False):
    """Hypergradient via conjugate gradient on the fixed-point linear system.

    ``CG_torch.cg`` (at most ``K`` iterations, tolerance ``tol``) solves
    ``v - J^T v = dE/dw``, where ``J^T v`` is the vector-Jacobian product of
    ``fp_map`` w.r.t. ``params``. The hypergradient is
    ``dE/dhparams + (d fp_map / d hparams)^T v``.

    Args:
        params: inner solution (approximate fixed point); detached copies are used.
        hparams: outer variables.
        K: maximum number of CG iterations.
        fp_map: fixed-point map ``(params, hparams) -> params``.
        outer_loss: outer objective ``(params, hparams) -> scalar``.
        tol: CG tolerance (``epsilon``).
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.
        stochastic: if True, ``fp_map`` is re-evaluated for every operator
            application and once more for the final product, and no graph is
            retained between calls. Otherwise one ``w_mapped`` graph is built
            up front and reused.

    Returns:
        One hypergradient per element of ``hparams``. Hyperparameters that
        ``fp_map`` does not depend on receive only their outer-loss gradient.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    if not stochastic:
        w_mapped = fp_map(params, hparams)

    def dfp_map_dw(xs):
        if stochastic:
            w_mapped_in = fp_map(params, hparams)
            Jfp_mapTv = torch_grad(w_mapped_in,
                                   params,
                                   grad_outputs=xs,
                                   retain_graph=False)
        else:
            Jfp_mapTv = torch_grad(w_mapped,
                                   params,
                                   grad_outputs=xs,
                                   retain_graph=True)
        return [v - j for v, j in zip(xs, Jfp_mapTv)]

    vs = CG_torch.cg(dfp_map_dw, grad_outer_w, max_iter=K,
                     epsilon=tol)  # K steps of conjugate gradient

    if stochastic:
        w_mapped = fp_map(params, hparams)

    # allow_unused: a hyperparameter with no path through fp_map would make
    # torch_grad raise. Treat its implicit part as zero instead, as
    # CG_normaleq does.
    grads = torch_grad(w_mapped, hparams, grad_outputs=vs, allow_unused=True)
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
def calc_gradient_penalty(netD, real_data, generated_data, cuda_available=False):
    """WGAN-GP gradient penalty ``10 * mean((||g||_2 - 1) ** 2)``.

    Args:
        netD: critic called on the interpolates.
        real_data: real batch (dimension 0 is the batch).
        generated_data: generated batch of the same shape.
        cuda_available: kept for backward compatibility and no longer
            consulted. Tensors are placed on ``real_data``'s device, which is
            the only placement under which the original arithmetic succeeded.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    # GP strength
    LAMBDA = 10
    b_size = real_data.size(0)

    # One interpolation coefficient per example, broadcast over the remaining
    # dims. The original hard-coded a 4-D (B, 1, 1, 1) shape.
    alpha = torch.rand(b_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.to(real_data.device)
    interpolated = (alpha * real_data.detach() +
                    (1 - alpha) * generated_data.detach())
    interpolated.requires_grad_(True)

    # Calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True)[0]

    # Flatten every non-batch dimension to take one norm per example.
    gradients = gradients.reshape(b_size, -1)

    # Derivatives of the norm near 0 are unstable because of the square root,
    # so compute it manually with a small epsilon.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return LAMBDA * ((gradients_norm - 1) ** 2).mean()
def gradient_penalty(real_data, generated_data):
    """WGAN-GP penalty ``gp_weight * mean((||g||_2 - 1) ** 2)`` for critic ``D``.

    ``D`` and ``gp_weight`` are module-level names defined elsewhere.

    Args:
        real_data: real batch (dimension 0 is the batch).
        generated_data: generated batch of the same shape.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).

    Raises:
        RuntimeError: if ``D``'s output does not depend on its input.
    """
    batch_size = real_data.size(0)
    # One interpolation coefficient per example, broadcast over the remaining
    # dims (the original hard-coded (B, 1, 1)). It is placed on real_data's
    # device instead of the default CUDA device.
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.to(real_data.device)
    interpolated = (alpha * real_data.detach() +
                    (1 - alpha) * generated_data.detach())
    interpolated.requires_grad_(True)
    del alpha
    # NOTE(review): emptying the CUDA cache on every call forces a sync and is
    # usually unnecessary; kept to preserve the original memory behaviour.
    torch.cuda.empty_cache()

    # Calculate probability of interpolated examples
    prob_interpolated = D(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True,
                           allow_unused=True)[0]
    if gradients is None:
        # With allow_unused=True an unconnected input yields None; the original
        # then crashed with an opaque AttributeError on ``None.cuda()``.
        raise RuntimeError(
            "D's output does not depend on its input; "
            "cannot compute a gradient penalty")

    # Flatten every non-batch dimension to take one norm per example;
    # reshape copes with non-contiguous gradients.
    gradients = gradients.reshape(batch_size, -1)

    # Derivatives of the norm near 0 are unstable because of the square root,
    # so compute it manually with a small epsilon.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return gp_weight * ((gradients_norm - 1) ** 2).mean()
def _gradient_penalty(self, real_data, generated_data, aux_data): assert real_data.size() == generated_data.size(), ( 'real and generated mini batches must ' 'have same size ({a} and {b})').format(a=real_data.size(), b=generated_data.size()) batch_size = real_data.size(0) # Calculate interpolation alpha = torch.rand(batch_size, *[1 for _ in range(real_data.dim() - 1)]) #alpha = alpha.expand_as(real_data) if self.use_cuda: alpha = alpha.cuda() interpolated = alpha * real_data.data + (1. - alpha) * generated_data.data interpolated = Variable(interpolated, requires_grad=True) if self.use_cuda: interpolated = interpolated.cuda() # Calculate distance of interpolated examples d_interpolated = self.critic(interpolated, aux_x=aux_data) # Calculate gradients of probabilities with respect to examples gradients = torch_grad( outputs=d_interpolated, inputs=interpolated, grad_outputs=torch.ones(d_interpolated.size()).cuda() if self.use_cuda else torch.ones(d_interpolated.size()), create_graph=True, retain_graph=True, only_inputs=True)[0] # Derivatives of the gradient close to 0 can cause problems because of # the square root, so manually calculate norm and add epsilon gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12) ## Return gradient penalty #if self.verbose > 0: # if i % self.print_every == 0: # self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().data) return ((gradients_norm - 1)**2).mean()
def gradient_penalty(learner_sa, expert_sa, f):
    """Two-sided gradient penalty ``mean((||grad f||_2 - 0.4) ** 2)``.

    ``f`` is evaluated on ``interpolated.float()``, where the interpolates lie
    between ``expert_sa`` and ``learner_sa`` with one random coefficient per
    example. The target gradient norm is 0.4 rather than 1.

    Args:
        learner_sa: learner batch (dimension 0 is the batch; presumably
            state-action pairs -- TODO confirm).
        expert_sa: expert batch of the same shape.
        f: callable scored on the interpolates.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    batch_size = expert_sa.size(0)
    # One coefficient per example on the data's device (the original always
    # created it on the CPU, which fails for GPU inputs).
    alpha = torch.rand(batch_size, *([1] * (expert_sa.dim() - 1)))
    alpha = alpha.to(expert_sa.device)
    interpolated = (alpha * expert_sa.detach() +
                    (1 - alpha) * learner_sa.detach())
    interpolated.requires_grad_(True)

    f_interpolated = f(interpolated.float())
    gradients = torch_grad(outputs=f_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(f_interpolated),
                           create_graph=True,
                           retain_graph=True)[0]
    gradients = gradients.reshape(batch_size, -1)
    # (An unused ``norm = ...mean().item()`` was removed: it forced a host sync
    # every call and its value was never read.)
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    # Target norm 0.4; original note: 2 * |f'(x_0)|
    return ((gradients_norm - 0.4) ** 2).mean()
def gradient_penalty(images, output, labels, weight=10):
    """
    Compute the gradient penalty of the output with respect to the images.

    The penalty is ``weight * mean((||d output / d images_b||_2 - 1) ** 2)``.

    :param images: input of the network (requires grad; dimension 0 is the batch).
    :param output: output of the network, computed from ``images``.
    :param labels: labels of the input, unused.
    :param weight: factor by which to multiply the penalty.
    :return: scalar penalty tensor, kept in the graph (``create_graph=True``).
    """
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=output,
                           inputs=images,
                           # Seed on output's own device/dtype instead of the
                           # hard-coded default CUDA device.
                           grad_outputs=torch.ones_like(output),
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)[0]
    # reshape (not view): the gradient is not guaranteed contiguous.
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def calc_gradient_penalty(self, real_data, generated_data):
    """WGAN-GP gradient penalty ``10 * mean((||g||_2 - 1) ** 2)`` for ``self.netD``.

    Args:
        real_data: real batch (dimension 0 is the batch).
        generated_data: generated batch of the same shape.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    # gp weight
    gp_weight = 10
    b_size = real_data.size(0)

    # Random weight term for interpolation between real and fake samples:
    # ONE coefficient per example, broadcast over all other dims. The original
    # drew a separate coefficient per voxel with a hard-coded
    # (B, 1, 64, 64, 64) shape. That only fits 64^3 volumes, and it does not
    # sample points on the straight line between a real and a fake example.
    alpha = torch.rand(b_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.to(real_data.device)

    # Random interpolation between real and fake samples
    interpolated = (alpha * real_data.detach() +
                    (1 - alpha) * generated_data.detach())
    interpolated.requires_grad_(True)

    # Calculate probability of interpolated examples
    prob_interpolated = self.netD(interpolated)

    # Calculate gradients of probabilities with respect to examples
    # (only_inputs=True was added here by the original author).
    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)[0]

    # Flatten every non-batch dimension to take one norm per example.
    gradients = gradients.reshape(b_size, -1)

    # Derivatives of the norm near 0 are unstable because of the square root,
    # so compute it manually with a small epsilon.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return gp_weight * ((gradients_norm - 1) ** 2).mean()
def _gradient_penalty(self, real_data, generated_data): """Compute the gradient penalty for the current update. From https://github.com/EmilienDupont/wgan-gp/blob/master/training.py -> _gradient_penalty() """ discriminator = self batch_size = real_data.size()[0] # Calculate interpolation alpha = torch.rand(batch_size, 1) alpha = alpha.expand_as(real_data) if self.use_cuda: alpha = alpha.cuda() interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data interpolated = Variable(interpolated, requires_grad=True) if self.use_cuda: interpolated = interpolated.cuda() # Calculate probability of interpolated examples prob_interpolated = discriminator(interpolated) # Calculate gradients of probabilities with respect to examples gradients = torch_grad( outputs=prob_interpolated, inputs=interpolated, grad_outputs=torch.ones(prob_interpolated.size()).cuda() if self.use_cuda else torch.ones(prob_interpolated.size()), create_graph=True, retain_graph=True)[0] # Gradients have shape (batch_size, num_channels, img_width, img_height), # so flatten to easily take norm per example in batch gradients = gradients.view(batch_size, -1) # Derivatives of the gradient close to 0 can cause problems because of # the square root, so manually calculate norm and add epsilon # gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) # Return gradient penalty return self.LAMBDA * ((gradients.norm(2, dim=1) - 1)**2).mean()
def _gradient_penalty_centered_(c_real, model, gp_weight, center=0.): """Gradient penalty for the discriminator""" B = c_real.size(0) c_real.requires_grad_(True) # Calculate gradients of probabilities with respect to examples #make_dot(d).view() d = model.D(c_real) gradients = torch_grad(outputs=d, inputs=c_real, grad_outputs=torch.ones(d.size()).cuda(), create_graph=True, retain_graph=True)[0] # Gradients have shape (B, num_channels, img_width, img_height), # so flatten to easily take norm per example in batch gradients = gradients.contiguous().view(B, -1) gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12) # Return gradient penalty return gp_weight * ((gradients_norm - center)**2).mean()
def _gradient_penalty(self, real_data, generated_data): batch_size = real_data.size()[0] # Calculate interpolation alpha = torch.rand(batch_size, 1, 1, 1) alpha = alpha.expand_as(real_data) one = torch.ones(1) if self.use_cuda: alpha = alpha.cuda() one = one.cuda() interpolated = alpha * real_data.data + (one - alpha) * generated_data.data interpolated.requires_grad_() if self.use_cuda: interpolated = interpolated.cuda() # Calculate probability of interpolated examples prob_interpolated = self.D(interpolated) # Calculate gradients of probabilities with respect to examples gradients = torch_grad( outputs=prob_interpolated, inputs=interpolated, grad_outputs=torch.ones(prob_interpolated.size()).cuda() if self.use_cuda else torch.ones(prob_interpolated.size()), create_graph=True, retain_graph=True)[0] # Gradients have shape (batch_size, num_channels, img_width, img_height), # so flatten to easily take norm per example in batch gradients = gradients.view(batch_size, -1) self.losses['gradient_norm'].append( gradients.norm(2, dim=1).mean().item()) # Derivatives of the gradient close to 0 can cause problems because of # the square root, so manually calculate norm and add epsilon gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12) # Return gradient penalty return self.gp_weight * ((gradients_norm - 1)**2).mean()
def GradientPenalty(discriminator_model, real_data, generated_data, gp_weight=10):
    """WGAN-GP penalty ``gp_weight * mean((||g||_2 - 1) ** 2)``.

    Args:
        discriminator_model: critic called on the interpolates.
        real_data: real batch (dimension 0 is the batch).
        generated_data: generated batch of the same shape.
        gp_weight: multiplier for the penalty.

    Returns:
        Scalar tensor kept in the graph (``create_graph=True``).
    """
    batch_size = real_data.size(0)

    # One interpolation coefficient per example, broadcast over the remaining
    # dims (the original hard-coded 4-D). It is placed on real_data's device.
    # The original chose the device with torch.cuda.is_available(), which
    # breaks CPU data on any machine that has a GPU.
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.to(real_data.device)
    interpolated = alpha * real_data + (1 - alpha) * generated_data
    # Detached leaf that requires grad (what ``Variable(..., requires_grad=True)``
    # produced).
    interpolated = interpolated.detach().requires_grad_(True)

    # Calculate probability of interpolated examples
    prob_interpolated = discriminator_model(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True)[0]

    # Flatten every non-batch dimension to take one norm per example.
    gradients = gradients.reshape(batch_size, -1)

    # Derivatives of the norm near 0 are unstable because of the square root,
    # so compute it manually with a small epsilon.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return gp_weight * ((gradients_norm - 1) ** 2).mean()
def _gradient_penalty(self, real_data, generated_data): batch_size = real_data.size()[0] # Calculate interpolation alpha = torch.rand(batch_size, 1) alpha = alpha.expand_as(real_data) if self.use_cuda: alpha = alpha.cuda() interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data interpolated = Variable(interpolated, requires_grad=True) if self.use_cuda: interpolated = interpolated.cuda() # Pass interpolated data through Critic prob_interpolated = self.c(interpolated) # Calculate gradients of probabilities with respect to examples gradients = torch_grad( outputs=prob_interpolated, inputs=interpolated, grad_outputs=torch.ones(prob_interpolated.size()).cuda() if self.use_cuda else torch.ones(prob_interpolated.size()), create_graph=True, retain_graph=True)[0] # Gradients have shape (batch_size, num_channels, series length), # here we flatten to take the norm per example for every batch gradients = gradients.view(batch_size, -1) self.losses['gradient_norm'].append( gradients.norm(2, dim=1).mean().data.item()) # Derivatives of the gradient close to 0 can cause problems because of the # square root, so manually calculate norm and add epsilon gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12) # Return gradient penalty return self.gp_weight * ((gradients_norm - 1)**2).mean()
def __call__(self, img: torch.Tensor,
             cls_idx: int) -> "tuple[np.ndarray, np.ndarray]":
    """Compute a Grad-CAM style heatmap for class ``cls_idx``.

    img: input batch fed to ``self.model``; only the first element's heatmap
        is returned.
    cls_idx: class index whose evidence is visualised.
    return: ``(heatmap, probs)``.
        ``heatmap`` is a ``(1, 299, 299)`` array rescaled to [0, 1]; it is
        all zeros when the map is constant.
        ``probs`` is the softmax of the model's unmodified output over dim 1.
    """
    # get output and feature map from CNN
    output = self.model(img)
    feature_maps = self.f_get_last_module(self.model).output

    # Select the target class through a one-hot seed rather than zeroing the
    # other logits in place. The gradient is identical, but ``output`` keeps
    # the real logits, so ``probs`` below is correct. It also avoids an
    # in-place edit of a tensor autograd may have saved.
    one_hot = torch.zeros_like(output)
    one_hot[:, cls_idx] = 1

    # partial derivative of the selected output w.r.t. each feature-map cell
    gradients = torch_grad(outputs=output,
                           inputs=feature_maps,
                           grad_outputs=one_hot,
                           create_graph=False,
                           retain_graph=True)[0]

    # Channel weights: spatial mean of the gradients. This equals
    # avg_pool2d with a full-size kernel on square maps and stays correct for
    # non-square ones.
    alpha = gradients.mean(dim=(2, 3), keepdim=True)

    # create localization map
    heatmap = F.relu(torch.sum(feature_maps * alpha, dim=1, keepdim=True))
    heatmap = cv2.resize(heatmap[0, 0, :, :].detach().cpu().numpy(), (299, 299))

    # Rescale to [0, 1]. A constant map (e.g. all zero after ReLU) would
    # otherwise divide by zero and produce NaNs.
    lo, hi = np.min(heatmap), np.max(heatmap)
    if hi > lo:
        heatmap = (heatmap - lo) / (hi - lo)
    else:
        heatmap = np.zeros_like(heatmap)
    heatmap = np.expand_dims(heatmap, axis=0)

    # probabilities from the untouched logits
    probs = F.softmax(output.detach(), dim=1).cpu().numpy()
    return heatmap, probs
def _grad_penalty(self, real_data, gen_data): batch_size = real_data.size()[0] t = torch.rand((batch_size, 1, 1), requires_grad=True) t = t.expand_as(real_data) if self.use_cuda: t = t.cuda() # mixed sample from real and fake; make approx of the 'true' gradient norm interpol = t * real_data.data + (1 - t) * gen_data.data if self.use_cuda: interpol = interpol.cuda() prob_interpol = self.D(interpol) torch.autograd.set_detect_anomaly(True) gradients = torch_grad( outputs=prob_interpol, inputs=interpol, grad_outputs=torch.ones(prob_interpol.size()).cuda() if self.use_cuda else torch.ones(prob_interpol.size()), create_graph=True, retain_graph=True)[0] gradients = gradients.view(batch_size, -1) #grad_norm = torch.norm(gradients, dim=1).mean() #self.losses['gradient_norm'].append(grad_norm.item()) # add epsilon for stability eps = 1e-10 gradients_norm = torch.sqrt( torch.sum(gradients**2, dim=1, dtype=torch.double) + eps) #gradients = gradients.cpu() # comment: precision is lower than grad_norm (think that is double) and gradients_norm is float return self.gp_weight * (torch.max( torch.zeros(1, dtype=torch.double).cuda() if self.use_cuda else torch.zeros(1, dtype=torch.double), gradients_norm.mean() - 1)**2), gradients_norm.mean().item()