def d_r1_loss(real_pred, real_img):
    """Return the mean per-sample squared gradient norm of the predictions.

    Differentiates ``real_pred.sum()`` with respect to ``real_img`` (keeping
    the graph so the result can be backpropagated), flattens the gradient of
    each sample along dim 0, sums its squares, and averages over dim 0.
    No weighting factor is applied here.

    NOTE(review): a later ``def d_r1_loss`` in this file with the same name
    replaces this definition at import time.
    """
    # presumably suppresses conv weight-gradient work during this
    # input-gradient pass -- TODO confirm against op.conv2d_gradfix
    with conv2d_gradfix.no_weight_gradients():
        scalar_out = real_pred.sum()
        grads = autograd.grad(
            outputs=scalar_out,
            inputs=real_img,
            create_graph=True,
        )
    (grad_real,) = grads

    batch_size = grad_real.shape[0]
    per_sample_sq_norm = grad_real.pow(2).reshape(batch_size, -1).sum(1)
    return per_sample_sq_norm.mean()
def d_r1_loss(real_pred, real_img, args):
    """Compute the gradient penalty on real samples.

    Returns the batch mean of the squared L2 norm of
    ``d(sum(real_pred)) / d(real_img)``, with each sample's gradient
    flattened along dim 0. The graph is kept (``create_graph=True``) so the
    penalty itself can be backpropagated. No weighting factor is applied
    here; presumably the caller scales it -- TODO confirm.

    Args:
        real_pred: Outputs computed from ``real_img``; summed to a scalar
            before differentiation.
        real_img: Tensor that ``real_pred`` depends on; must require grad.
            Dim 0 is treated as the batch dimension.
        args: Config object whose ``useConvdFix`` attribute (treated as a
            boolean flag) selects whether the gradient is taken inside
            ``op.conv2d_gradfix.no_weight_gradients()``.

    Returns:
        Scalar tensor: mean over dim 0 of the per-sample squared gradient norm.
    """
    import contextlib

    if args.useConvdFix:
        # Imported lazily so the custom ``op`` extension is only needed when
        # the fix is actually enabled.
        from op import conv2d_gradfix

        # presumably suppresses conv weight-gradient work during this
        # input-gradient pass -- TODO confirm against op.conv2d_gradfix
        grad_ctx = conv2d_gradfix.no_weight_gradients()
    else:
        grad_ctx = contextlib.nullcontext()

    # Single gradient computation shared by both paths.
    with grad_ctx:
        (grad_real,) = autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )

    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty