示例#1
0
def gradient_penalty(images, output, weight=10):
    """Return ``weight * mean((||d output / d images||_2 - 1) ** 2)``.

    The gradient of ``output`` with respect to ``images`` is flattened per
    sample (dim 0 is the batch) before taking its L2 norm.

    Args:
        images: tensor ``output`` was computed from; must require grad.
        output: tensor computed from ``images`` (e.g. critic scores).
        weight: scale applied to the penalty.

    Returns:
        Scalar tensor; ``create_graph=True`` keeps it differentiable.
    """
    batch_size = images.shape[0]
    gradients = torch_grad(
        outputs=output,
        inputs=images,
        # Seed on output's own device/dtype instead of forcing CUDA, so the
        # penalty also works on CPU and non-default devices.
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # reshape (unlike view) also accepts non-contiguous gradients.
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1)**2).mean()
def calc_pl_lengths(styles, images):
    """Return path lengths of ``images`` with respect to ``styles``.

    ``images`` (indexed as ``(batch, C, H, W)``) is projected onto Gaussian
    noise scaled by ``1 / sqrt(H * W)``. The gradient of that scalar with
    respect to ``styles`` is squared, summed over dim 2, averaged over dim 1
    and square-rooted. For 3-D ``styles`` this gives one value per sample.

    Args:
        styles: tensor ``images`` was generated from; must require grad.
            Presumably ``(batch, n_layers, style_dim)`` — TODO confirm.
        images: generator output.

    Returns:
        Tensor of path lengths; the autograd graph is kept.
    """
    num_pixels = images.shape[2] * images.shape[3]
    # Allocate the noise on the images' device rather than forcing CUDA.
    pl_noise = torch.randn(images.shape,
                           device=images.device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs,
                          inputs=styles,
                          grad_outputs=torch.ones_like(outputs),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]

    return (pl_grads**2).sum(dim=2).mean(dim=1).sqrt()
示例#3
0
def gradient_penalty(images, output, weight=10):
    """Penalise deviation of the per-sample input-gradient norm from 1.

    Returns ``weight * mean((||d output / d images||_2 - 1) ** 2)``, where
    each sample's gradient (dim 0 = batch) is flattened before the norm.
    The result keeps its autograd graph.
    """
    # Seed gradient of ones, allocated next to the inputs.
    seed = torch.ones(output.size(), device=images.device)
    grads = torch_grad(outputs=output,
                       inputs=images,
                       grad_outputs=seed,
                       create_graph=True,
                       retain_graph=True,
                       only_inputs=True)[0]

    norms = grads.reshape(images.shape[0], -1).norm(2, dim=1)
    mean_sq_excess = ((norms - 1)**2).mean()
    return weight * mean_sq_excess
示例#4
0
 def forward(self, styles, images, device):
     """Return a scalar path-length value of ``images`` w.r.t. ``styles``.

     Noise scaled by ``1 / sqrt(H * W)`` is projected against ``images``.
     The gradient of that projection with respect to ``styles`` is squared,
     summed over all of its elements and square-rooted.

     NOTE(review): ``.mean()`` follows a full ``.sum()`` and is therefore a
     no-op; this returns one scalar, not per-sample lengths. Confirm that
     this is intended.
     """
     height, width = images.shape[2], images.shape[3]
     scale = math.sqrt(height * width)
     noise = torch.randn(images.shape).to(device=device) / scale
     projection = torch.sum(images * noise)
     seed = torch.ones(projection.shape).to(device=device)
     grads = torch_grad(outputs=projection,
                        inputs=styles,
                        grad_outputs=seed,
                        create_graph=True,
                        retain_graph=True,
                        only_inputs=True)[0]
     squared_total = (grads**2).sum()
     return squared_total.mean().sqrt()
def gradient_penalty(images, outputs, weight=10):
    """Gradient penalty for several outputs computed from the same ``images``.

    ``outputs`` is iterated and each element gets its own ones seed, so it
    must be a sequence of tensors. Autograd accumulates their gradients with
    respect to ``images``. The accumulated gradient is flattened per sample
    and the result is ``weight * mean((||grad||_2 - 1) ** 2)``.

    Returns:
        Scalar tensor; the autograd graph is kept.
    """
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=outputs,
                           inputs=images,
                           # One seed per output, on that output's own
                           # device/dtype (previously forced onto CUDA).
                           grad_outputs=[torch.ones_like(t) for t in outputs],
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1)**2).mean()
示例#6
0
def CG_normaleq(params: List[Tensor],
                hparams: List[Tensor],
                K: int,
                fp_map: Callable[[List[Tensor], List[Tensor]], List[Tensor]],
                outer_loss: Callable[[List[Tensor], List[Tensor]], Tensor],
                tol=1e-10,
                set_grad=True) -> List[Tensor]:
    """Hypergradient via conjugate gradient applied to the normal equation.

    Similar to CG, but the conjugate gradient is applied to the normal
    equation, which has a higher time complexity. Every application of the
    linear operator costs one vector-Jacobian product (``torch_grad``) and
    one ``jvp`` call through ``fp_map``.

    Args:
        params: inner parameters. Detached copies marked ``requires_grad``
            are used, so the caller's tensors are not modified.
        hparams: outer variables (hyperparameters).
        K: maximum number of conjugate-gradient iterations (``max_iter``).
        fp_map: fixed-point map, called as ``fp_map(params, hparams)``.
        outer_loss: outer objective, called as ``outer_loss(params, hparams)``.
        tol: tolerance forwarded to ``CG_torch.cg`` as ``epsilon``.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.

    Returns:
        One gradient tensor per element of ``hparams``.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    # Direct gradients of the outer loss w.r.t. params and hparams.
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    # One application of the fixed-point map; its graph is reused by every
    # vector-Jacobian product below.
    w_mapped = fp_map(params, hparams)

    def dfp_map_dw(xs):
        # J^T xs as a vector-Jacobian product (J = d fp_map / d params). The
        # graph is retained because this runs once per CG iteration.
        Jfp_mapTv = torch_grad(w_mapped,
                               params,
                               grad_outputs=xs,
                               retain_graph=True)
        v_minus_Jfp_mapTv = [v - j for v, j in zip(xs, Jfp_mapTv)]

        # Normal-equation part: apply (I - J) to (I - J^T) xs, assuming
        # ``jvp`` returns the Jacobian-vector product J v — TODO confirm.
        Jfp_mapv_minus_Jfp_mapJfp_mapTv = jvp(
            lambda _params: fp_map(_params, hparams), params,
            v_minus_Jfp_mapTv)
        return [
            v - vv for v, vv in zip(v_minus_Jfp_mapTv,
                                    Jfp_mapv_minus_Jfp_mapJfp_mapTv)
        ]

    # Right-hand side: g - jvp(fp_map, g), with g the outer-loss gradient
    # w.r.t. params.
    v_minus_Jfp_mapv = [
        g - jfp_mapv for g, jfp_mapv in zip(
            grad_outer_w,
            jvp(lambda _params: fp_map(_params, hparams), params,
                grad_outer_w))
    ]
    vs = CG_torch.cg(dfp_map_dw, v_minus_Jfp_mapv, max_iter=K,
                     epsilon=tol)  # K steps of conjugate gradient

    # Indirect term vs^T d fp_map / d hparams. Hyperparameters that fp_map
    # does not use come back as None (allow_unused=True).
    grads = torch_grad(w_mapped, hparams, grad_outputs=vs, allow_unused=True)
    # Add the direct outer-loss term; keep it alone where the indirect
    # term is None.
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
示例#7
0
def neumann(params, hparams, K, fp_map, outer_loss, tol=1e-10, set_grad=True):
    """Approximate the hypergradient with a truncated Neumann series.

    From https://arxiv.org/pdf/1803.06396.pdf; should return the same
    gradient as fixed point with K+1 steps.

    Accumulates ``g + J^T g + (J^T)^2 g + ...``, where ``J`` is the Jacobian
    of ``fp_map`` w.r.t. ``params`` and ``g`` is d outer_loss / d params. At
    most ``K`` terms are added after the first, stopping early once the
    accumulated vector changes by less than ``tol`` in L2 norm.

    Args:
        params: inner parameters; detached copies requiring grad are used.
        hparams: outer variables (hyperparameters).
        K: maximum number of extra series terms.
        fp_map: fixed-point map, called as ``fp_map(params, hparams)``.
        outer_loss: outer objective, called as ``outer_loss(params, hparams)``.
        tol: early-stopping threshold on the change of the accumulated sum.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.

    Returns:
        One gradient per element of ``hparams``. A hyperparameter that
        ``fp_map`` does not depend on receives only its direct outer-loss
        gradient instead of making ``torch_grad`` raise.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    w_mapped = fp_map(params, hparams)
    vs, gs = grad_outer_w, grad_outer_w
    gs_vec = cat_list_to_tensor(gs)
    for _ in range(K):
        gs_prev_vec = gs_vec
        # vs <- J^T vs via a vector-Jacobian product on the retained graph.
        vs = torch_grad(w_mapped, params, grad_outputs=vs, retain_graph=True)
        gs = [g + v for g, v in zip(gs, vs)]
        gs_vec = cat_list_to_tensor(gs)
        if float(torch.norm(gs_vec - gs_prev_vec)) < tol:
            break

    # allow_unused: a hyperparameter may influence only the outer loss, in
    # which case its indirect term is None and only the direct term is kept.
    grads = torch_grad(w_mapped, hparams, grad_outputs=gs, allow_unused=True)
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]
    if set_grad:
        update_tensor_grads(hparams, grads)
    return grads
示例#8
0
    def _gradient_penalty(self, real_data: torch.Tensor,
                          generated_data: torch.Tensor) -> torch.Tensor:
        """
        Computes gradient penalty for given pair of real and
        generated data.
        Applies gradient penalty weight as well.

        One interpolation coefficient is drawn per sample and broadcast over
        all remaining dimensions, so inputs of any rank are supported
        (previously only 4-D). The mean per-sample gradient norm is appended
        to ``self.losses['gradient_norm']`` as a Python float.

        :param real_data: Example from training data.
        :param generated_data: Output from Generator.
        :return: 0-dim tensor ``self.gp_weight * mean((||grad|| - 1) ** 2)``
            that keeps its autograd graph.
        """
        batch_size = real_data.size()[0]

        # Calculate interpolation: alpha has shape (batch, 1, ..., 1).
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
        alpha = alpha.expand_as(real_data)

        if self.use_cuda:
            alpha = alpha.cuda()

        # A fresh leaf, detached from the data/generator graphs. This replaces
        # the deprecated Variable(... .data, requires_grad=True).
        interpolated = (alpha * real_data.detach() +
                        (1 - alpha) * generated_data.detach())
        if self.use_cuda:
            interpolated = interpolated.cuda()
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = self.D(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,
            retain_graph=True)[0]

        # Flatten so we can take one norm per example in the batch.
        gradients = gradients.reshape(batch_size, -1)
        gradnorm = gradients.norm(2, dim=1)
        self.losses['gradient_norm'].append(gradnorm.mean().item())

        # Return gradient penalty
        return self.gp_weight * ((gradnorm - 1)**2).mean()
示例#9
0
def CG_normaleq(params,
                hparams,
                K,
                fp_map,
                outer_loss,
                tol=1e-10,
                set_grad=True):
    """Hypergradient via conjugate gradient applied to the normal equation.

    Each operator application costs one vector-Jacobian product
    (``torch_grad``) and one ``jvp`` call through ``fp_map``.

    Args:
        params: inner parameters; detached copies requiring grad are used.
        hparams: outer variables (hyperparameters).
        K: maximum number of conjugate-gradient iterations (``max_iter``).
        fp_map: fixed-point map, called as ``fp_map(params, hparams)``.
        outer_loss: outer objective, called as ``outer_loss(params, hparams)``.
        tol: tolerance forwarded to ``CG_torch.cg`` as ``epsilon``.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.

    Returns:
        One gradient tensor per element of ``hparams``.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    # Direct gradients of the outer loss w.r.t. params and hparams.
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    # One fixed-point application; its graph is reused by every VJP below.
    w_mapped = fp_map(params, hparams)

    def dfp_map_dw(xs):
        # J^T xs as a vector-Jacobian product; the graph is retained because
        # this runs once per CG iteration.
        Jfp_mapTv = torch_grad(w_mapped,
                               params,
                               grad_outputs=xs,
                               retain_graph=True)
        v_minus_Jfp_mapTv = [v - j for v, j in zip(xs, Jfp_mapTv)]

        # Normal-equation part: apply (I - J) to (I - J^T) xs, assuming
        # ``jvp`` returns the Jacobian-vector product J v — TODO confirm.
        Jfp_mapv_minus_Jfp_mapJfp_mapTv = jvp(
            lambda _params: fp_map(_params, hparams), params,
            v_minus_Jfp_mapTv)
        return [
            v - vv for v, vv in zip(v_minus_Jfp_mapTv,
                                    Jfp_mapv_minus_Jfp_mapJfp_mapTv)
        ]

    # Right-hand side: g - jvp(fp_map, g), with g = d outer_loss / d params.
    v_minus_Jfp_mapv = [
        g - jfp_mapv for g, jfp_mapv in zip(
            grad_outer_w,
            jvp(lambda _params: fp_map(_params, hparams), params,
                grad_outer_w))
    ]
    vs = CG_torch.cg(dfp_map_dw, v_minus_Jfp_mapv, max_iter=K,
                     epsilon=tol)  # K steps of conjugate gradient

    # Indirect term; hyperparameters unused by fp_map come back as None.
    grads = torch_grad(w_mapped, hparams, grad_outputs=vs, allow_unused=True)
    # Add the direct outer-loss term; keep it alone where the indirect
    # term is None.
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
示例#10
0
    def dfp_map_dw(xs):
        """Normal-equation operator applied to the list of tensors ``xs``.

        Computes ``u = xs - J^T xs`` with a vector-Jacobian product through
        ``w_mapped``, then returns ``u - jvp(fp_map, u)``. If ``jvp`` returns
        J u (TODO confirm), this is ``(I - J)(I - J^T) xs``.

        Relies on the enclosing scope (not visible here) for ``w_mapped``,
        ``params``, ``fp_map`` and ``hparams``.
        """
        # Graph retained so the operator can be applied repeatedly.
        Jfp_mapTv = torch_grad(w_mapped,
                               params,
                               grad_outputs=xs,
                               retain_graph=True)
        v_minus_Jfp_mapTv = [v - j for v, j in zip(xs, Jfp_mapTv)]

        # normal equation part
        Jfp_mapv_minus_Jfp_mapJfp_mapTv = jvp(
            lambda _params: fp_map(_params, hparams), params,
            v_minus_Jfp_mapTv)
        return [
            v - vv for v, vv in zip(v_minus_Jfp_mapTv,
                                    Jfp_mapv_minus_Jfp_mapJfp_mapTv)
        ]
def gradient_penalty(D, real_images, fake_images, weight=10):
    """One-sided penalty on the max-abs input gradient of ``D``.

    ``interpolate`` (defined elsewhere) builds points from the real and fake
    images. The gradient of ``D`` at those points is flattened per sample and
    reduced with its largest absolute entry (an L-infinity norm). Only
    norms above 1 are penalised: ``weight * mean(relu(norm - 1))``.

    Returns:
        Scalar tensor; the autograd graph is kept.
    """
    x_interp = interpolate(real_images, fake_images)
    x_interp.requires_grad_(True)
    o_interp = D(x_interp)

    grad = torch_grad(outputs=o_interp,
                      inputs=x_interp,
                      # Seed on o_interp's own device/dtype instead of
                      # forcing CUDA, so CPU runs work too.
                      grad_outputs=torch.ones_like(o_interp),
                      create_graph=True,
                      retain_graph=True,
                      only_inputs=True)[0]

    grad = grad.reshape(grad.shape[0], -1)
    grad_norm, _ = torch.max(torch.abs(grad), 1)
    return weight * F.relu(grad_norm - 1).mean()
示例#12
0
def gradient_penalty(images, output, weight=10, return_structured_grads=False):
    """Gradient penalty, optionally returned with the raw gradients.

    Computes ``weight * mean((||d output / d images||_2 - 1) ** 2)`` using a
    per-sample flattened gradient. When ``return_structured_grads`` is true,
    returns ``(penalty, gradients)``, where ``gradients`` has the shape of
    ``images``.
    """
    ones = torch.ones(output.size(), device=images.device)
    grads = torch_grad(outputs=output,
                       inputs=images,
                       grad_outputs=ones,
                       create_graph=True,
                       retain_graph=True,
                       only_inputs=True)[0]

    norms = grads.reshape(images.shape[0], -1).norm(2, dim=1)
    penalty = weight * ((norms - 1)**2).mean()
    return (penalty, grads) if return_structured_grads else penalty
示例#13
0
def reverse(params_history: List[List[Tensor]],
            hparams: List[Tensor],
            update_map_history: List[Callable[[List[Tensor], List[Tensor]],
                                              List[Tensor]]],
            outer_loss: Callable[[List[Tensor], List[Tensor]], Tensor],
            set_grad=True) -> List[Tensor]:
    """
    Computes the hypergradient by recomputing and backpropagating through each inner update,
    using the inner iterates and the update maps previously employed by the inner solver.
    Similarly to checkpointing, this saves memory compared to reverse_unroll at the cost of
    extra computation. Truncated reverse can be performed by passing only part of the
    trajectory information, i.e. only the last k inner iterates and updates.

    Indexing: in the backward loop, params_history[k] is fed to update_map_history[k + 1]
    for k = -2, ..., -(K + 1), with K = len(params_history) - 1.

    Args:
        params_history: the inner iterates (from first to last)
        hparams: the outer variables (or hyperparameters), each element needs requires_grad=True
        update_map_history: updates used to solve the inner problem (from first to last)
        outer_loss: computes the outer objective taking parameters and hyperparameters as inputs
        set_grad: if True set t.grad to the hypergradient for every t in hparams

    Returns:
         the list of hypergradients for each element in hparams

    """
    # Fresh leaves for every iterate so each update can be re-run and differentiated.
    params_history = [[w.detach().requires_grad_(True) for w in params]
                      for params in params_history]
    o_loss = outer_loss(params_history[-1], hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params_history[-1], hparams)

    # alphas: adjoint vector, initialised with d outer_loss / d (last iterate).
    alphas = grad_outer_w
    # Accumulator for the indirect part of the hypergradient.
    grads = [torch.zeros_like(w) for w in hparams]
    K = len(params_history) - 1
    # Walk the trajectory backwards: k = -2, -3, ..., -(K + 1).
    for k in range(-2, -(K + 2), -1):
        # Recompute one inner update from the stored iterate.
        w_mapped = update_map_history[k + 1](params_history[k], hparams)
        # alphas^T d update / d hparams; grad_unused_zero presumably returns
        # zeros for hparams the update does not use — confirm.
        bs = grad_unused_zero(w_mapped,
                              hparams,
                              grad_outputs=alphas,
                              retain_graph=True)
        grads = [g + b for g, b in zip(grads, bs)]
        # Propagate the adjoint one step back: alphas^T d update / d w_k.
        alphas = torch_grad(w_mapped, params_history[k], grad_outputs=alphas)

    # Add the direct term d outer_loss / d hparams.
    grads = [g + v for g, v in zip(grads, grad_outer_hparams)]
    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
示例#14
0
文件: gaga.py 项目: LFetty/gaga
    def compute_gradient_penalty(self, real_data, fake_data):
        """One-sided penalty ``mean(max(||grad D(x_hat)|| - 1, 0) ** 2)``.

        Adapted from
        https://github.com/EmilienDupont/wgan-gp/blob/master/training.py and
        https://github.com/caogang/wgan-gp/blob/master/gan_toy.py.

        ``x_hat`` interpolates ``real_data`` and ``fake_data`` with one random
        coefficient per sample. The batch size is read from ``real_data``
        rather than ``self.batch_size``, so a smaller final batch no longer
        crashes. Only gradient norms above 1 are penalised, and no weight is
        applied here.

        Args:
            real_data: real batch; dim 0 is the batch.
            fake_data: generated batch with the same shape as ``real_data``.

        Returns:
            Scalar tensor that keeps its autograd graph.
        """
        batch_size = real_data.size(0)

        # One alpha per sample, shaped (batch, 1, ..., 1), moved to the
        # data's device.
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
        alpha = alpha.to(real_data.device)

        # A detached leaf, so the penalty does not back-propagate into the
        # generator; this is what Variable(..., requires_grad=True) did.
        interpolated = (alpha * real_data + (1 - alpha) * fake_data).detach()
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = self.D(interpolated)

        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones_like(prob_interpolated),
                               create_graph=True, retain_graph=True,
                               only_inputs=True)[0]

        # Per-sample L2 norm over the flattened gradient. This is identical
        # to the former dim=1 reduction for 2-D data. The epsilon keeps sqrt
        # differentiable at zero.
        gradients = gradients.reshape(batch_size, -1)
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # One-sided penalty: (max(||grad|| - 1, 0)) ** 2
        a = torch.max(gradients_norm - 1, torch.zeros_like(gradients_norm))
        return (a ** 2).mean()
示例#15
0
 def L_adv_loss(self, x_img, target_img, alpha, i):
     """WGAN-GP style gradient penalty for discriminator block ``i + 1``.

     Interpolates ``x_img`` and ``target_img`` with one random coefficient
     per sample, evaluates the module-level ``D.discriminator`` on the result
     and returns ``10 * mean((||grad|| - 1) ** 2)``. ``alpha`` is only
     forwarded to the discriminator; it is not the interpolation coefficient.

     The batch size and device now come from ``x_img``. They were previously
     hard-coded to 16 and ``cuda:1``.
     """
     batch_size = x_img.size(0)
     # One coefficient per sample, broadcast over the remaining dims.
     rands = torch.rand(batch_size, *([1] * (x_img.dim() - 1))).to(x_img.device)
     # A fresh leaf we differentiate against (what Variable(...) produced).
     interpolated = (rands * x_img + (1. - rands) * target_img).detach()
     interpolated.requires_grad_(True)
     logit = D.discriminator(interpolated, block_num=i + 1, alpha=alpha)
     gradients = torch_grad(outputs=logit,
                            inputs=interpolated,
                            grad_outputs=torch.ones_like(logit),
                            create_graph=True,
                            retain_graph=True)[0]
     # Improved WGAN (WGAN-GP) penalty on the per-sample gradient norm.
     gradients = gradients.reshape(batch_size, -1)
     gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)
     return 10 * ((gradients_norm - 1)**2).mean()
示例#16
0
        def get_fisher_info(n_samples=30):
            """Average squared gradient of ``self.D`` w.r.t. generator params.

            Draws ``n_samples`` single samples via ``self.sample_generator(1)``
            and accumulates the element-wise square of the gradient of
            ``self.D(sample)`` w.r.t. each generator parameter, then divides
            by ``n_samples``. Judging by the name, this is a diagonal
            Fisher-information estimate — confirm. Samples are processed one
            at a time so we don't run out of CUDA memory.

            Returns:
                List of tensors, one per generator parameter.
            """
            n_params = len(self.init_params)
            sums = []
            for param in generator.parameters():
                zeros = torch.zeros(param.shape)
                sums.append(zeros.cuda() if self.use_cuda else zeros)

            for _ in range(n_samples):
                log_probs = self.D(self.sample_generator(1))
                loss_grads = torch_grad(outputs=log_probs,
                                        inputs=list(generator.parameters()))
                for idx in range(n_params):
                    sums[idx] = sums[idx] + loss_grads[idx]**2

            return [total / n_samples for total in sums]
示例#17
0
def CG(params,
       hparams,
       K,
       fp_map,
       outer_loss,
       tol=1e-10,
       set_grad=True,
       stochastic=False):
    """Hypergradient via conjugate gradient on ``(I - J^T) v = g``.

    ``J`` is the Jacobian of ``fp_map`` w.r.t. ``params``, and ``g`` is the
    outer-loss gradient w.r.t. ``params``.

    Args:
        params: inner parameters; detached copies requiring grad are used.
        hparams: outer variables (hyperparameters).
        K: maximum number of conjugate-gradient iterations.
        fp_map: fixed-point map, called as ``fp_map(params, hparams)``.
        outer_loss: outer objective, called as ``outer_loss(params, hparams)``.
        tol: tolerance forwarded to ``CG_torch.cg`` as ``epsilon``.
        set_grad: if True, ``update_tensor_grads(hparams, grads)`` is called.
        stochastic: if True, ``fp_map`` is re-evaluated for every operator
            application and once more for the final product, instead of
            reusing one retained graph.

    Returns:
        One gradient per element of ``hparams``. Hyperparameters unused by
        ``fp_map`` receive only their direct outer-loss gradient instead of
        making ``torch_grad`` raise.
    """
    params = [w.detach().requires_grad_(True) for w in params]
    o_loss = outer_loss(params, hparams)
    grad_outer_w, grad_outer_hparams = get_outer_gradients(
        o_loss, params, hparams)

    if not stochastic:
        w_mapped = fp_map(params, hparams)

    def dfp_map_dw(xs):
        # Returns xs - J^T xs, using a vector-Jacobian product.
        if stochastic:
            w_mapped_in = fp_map(params, hparams)
            Jfp_mapTv = torch_grad(w_mapped_in,
                                   params,
                                   grad_outputs=xs,
                                   retain_graph=False)
        else:
            Jfp_mapTv = torch_grad(w_mapped,
                                   params,
                                   grad_outputs=xs,
                                   retain_graph=True)
        return [v - j for v, j in zip(xs, Jfp_mapTv)]

    vs = CG_torch.cg(dfp_map_dw, grad_outer_w, max_iter=K,
                     epsilon=tol)  # K steps of conjugate gradient

    if stochastic:
        w_mapped = fp_map(params, hparams)

    # allow_unused: a hyperparameter may affect only the outer loss; its
    # indirect term is then None and only the direct term is kept.
    grads = torch_grad(w_mapped, hparams, grad_outputs=vs, allow_unused=True)
    grads = [
        g + v if g is not None else v
        for g, v in zip(grads, grad_outer_hparams)
    ]

    if set_grad:
        update_tensor_grads(hparams, grads)

    return grads
示例#18
0
def calc_gradient_penalty(netD,
                          real_data,
                          generated_data,
                          cuda_available=False):
    """WGAN-GP penalty ``10 * mean((||grad netD(x_hat)|| - 1) ** 2)``.

    ``x_hat`` interpolates real and generated data with one random
    coefficient per sample. The coefficient has shape ``(batch, 1, ..., 1)``,
    so inputs of any rank are accepted (previously only 4-D).

    Args:
        netD: critic/discriminator callable.
        real_data: real batch; dim 0 is the batch.
        generated_data: generated batch with the same shape as ``real_data``.
        cuda_available: if True, the coefficients are moved with ``.cuda()``.

    Returns:
        Scalar tensor that keeps its autograd graph.
    """
    # GP strength
    LAMBDA = 10

    b_size = real_data.size()[0]

    # Calculate interpolation
    alpha = torch.rand(b_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.expand_as(real_data)
    if cuda_available:
        alpha = alpha.cuda()

    # A fresh leaf detached from the data/generator graphs. This replaces
    # the deprecated Variable(... .data, requires_grad=True).
    interpolated = (alpha * real_data.detach() +
                    (1 - alpha) * generated_data.detach())
    if cuda_available:
        interpolated = interpolated.cuda()
    interpolated.requires_grad_(True)

    # Calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # Gradients of the critic output w.r.t. the interpolates; the seed
    # follows the output's device/dtype.
    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True)[0]

    # Flatten so we can take one norm per example in the batch.
    gradients = gradients.reshape(b_size, -1)

    # Derivatives of the gradient close to 0 can cause problems because of
    # the square root, so manually calculate norm and add epsilon
    gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

    # Return gradient penalty
    return LAMBDA * ((gradients_norm - 1)**2).mean()
示例#19
0
def gradient_penalty(real_data, generated_data):
    """WGAN-GP penalty ``gp_weight * mean((||grad D(x_hat)|| - 1) ** 2)``.

    Uses the module-level ``D`` (critic) and ``gp_weight``, which are defined
    elsewhere. ``x_hat`` interpolates the two batches with one random
    coefficient per sample. The coefficient is broadcast over all remaining
    dims (previously hard-coded for 3-D input) and placed on ``real_data``'s
    device.

    Returns:
        Scalar tensor that keeps its autograd graph.
    """
    batch_size = real_data.size()[0]

    # Calculate interpolation
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.expand_as(real_data).to(real_data.device)
    # A fresh leaf detached from the data/generator graphs.
    interpolated = (alpha * real_data.detach() +
                    (1 - alpha) * generated_data.detach())
    interpolated.requires_grad_(True)

    # Release the coefficients before the critic pass (kept from the
    # original, presumably to bound peak GPU memory).
    del alpha
    torch.cuda.empty_cache()

    # Calculate probability of interpolated examples
    prob_interpolated = D(interpolated)

    # No allow_unused: if D ignores its input, torch raises a clear error
    # here. Before, it returned None and the next line failed obscurely.
    gradients = torch_grad(outputs=prob_interpolated,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True,
                           retain_graph=True)[0]

    # Flatten so we can take one norm per example; reshape handles
    # non-contiguous gradients.
    gradients = gradients.reshape(batch_size, -1)

    # Derivatives of the gradient close to 0 can cause problems because of
    # the square root, so manually calculate norm and add epsilon
    gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

    # Return gradient penalty
    return gp_weight * ((gradients_norm - 1)**2).mean()
示例#20
0
    def _gradient_penalty(self, real_data, generated_data, aux_data):
        """Unweighted WGAN-GP penalty ``mean((||grad critic(x_hat)|| - 1) ** 2)``.

        ``x_hat`` interpolates the two batches with one random coefficient per
        sample, broadcast over every non-batch dimension. ``aux_data`` is
        forwarded to the critic as ``aux_x``. Each sample's gradient is
        flattened before taking its norm. Previously only dim 1 was reduced,
        which is wrong for inputs with more than 2 dimensions.

        Raises:
            AssertionError: if the two batches differ in size.
        """
        assert real_data.size() == generated_data.size(), (
            'real and generated mini batches must '
            'have same size ({a} and {b})').format(a=real_data.size(),
                                                   b=generated_data.size())
        batch_size = real_data.size(0)

        # Calculate interpolation
        alpha = torch.rand(batch_size,
                           *[1 for _ in range(real_data.dim() - 1)])
        if self.use_cuda:
            alpha = alpha.cuda()

        # A fresh leaf detached from the data/generator graphs.
        interpolated = (alpha * real_data.detach() +
                        (1. - alpha) * generated_data.detach())
        if self.use_cuda:
            interpolated = interpolated.cuda()
        interpolated.requires_grad_(True)

        # Calculate distance of interpolated examples
        d_interpolated = self.critic(interpolated, aux_x=aux_data)

        gradients = torch_grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]

        # One norm per sample over all non-batch dims. The epsilon keeps the
        # square root differentiable near zero.
        gradients = gradients.reshape(batch_size, -1)
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

        return ((gradients_norm - 1)**2).mean()
示例#21
0
def gradient_penalty(learner_sa, expert_sa, f):
    """Penalty pulling the gradient norm of ``f`` towards 0.4 at interpolates.

    Points between ``expert_sa`` and ``learner_sa`` use one random
    coefficient per row. Because ``alpha`` is ``(batch, 1)``, 2-D inputs are
    assumed. ``f`` is evaluated on the float32-cast points, and the result
    is ``mean((||grad f|| - 0.4) ** 2)``.

    Returns:
        Scalar tensor that keeps its autograd graph.
    """
    batch_size = expert_sa.size()[0]

    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand_as(expert_sa)

    # A fresh leaf detached from the inputs' graphs (what Variable did).
    interpolated = (alpha * expert_sa.detach() +
                    (1 - alpha) * learner_sa.detach())
    interpolated.requires_grad_(True)

    f_interpolated = f(interpolated.float())

    gradients = torch_grad(outputs=f_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones_like(f_interpolated),
                           create_graph=True, retain_graph=True)[0]

    # The former unused `norm = ...mean().item()` forced a device sync and
    # has been removed.
    gradients = gradients.reshape(batch_size, -1)
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    # Target norm 0.4 (original note: 2 * |f'(x_0)|).
    return ((gradients_norm - 0.4) ** 2).mean()
def gradient_penalty(images, output, labels, weight=10):
    """
    Compute the gradient penalty of the output with respect to the images.

    The gradient of ``output`` w.r.t. ``images`` is flattened per sample
    (dim 0 = batch), and the result is
    ``weight * mean((||grad||_2 - 1) ** 2)``.

    :param images: input of the network (must require grad).
    :param output: output of the network.
    :param labels: labels of the input, unused.
    :param weight: factor by which to multiply the penalty.
    :return: scalar tensor that keeps its autograd graph.
    """
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=output,
                           inputs=images,
                           # Seed on output's device/dtype, not forced CUDA.
                           grad_outputs=torch.ones_like(output),
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1)**2).mean()
示例#23
0
    def calc_gradient_penalty(self, real_data, generated_data):
        """WGAN-GP penalty ``10 * mean((||grad netD(x_hat)|| - 1) ** 2)``.

        ``x_hat`` mixes real and generated data with a random coefficient of
        shape ``(batch, 1, *real_data.shape[2:])``: one value per sample and
        per spatial position, shared across channels. The spatial size now
        comes from ``real_data`` instead of a hard-coded 64 x 64 x 64, so
        other volume sizes work; 64^3 inputs draw exactly as before.

        Returns:
            Scalar tensor that keeps its autograd graph.
        """
        # gp weight
        gp_weight = 10

        b_size = real_data.size()[0]

        # Random weight term for interpolation between real and fake samples
        alpha = torch.rand(b_size, 1, *real_data.shape[2:])
        alpha = alpha.expand_as(real_data).to(real_data.device)

        # A fresh leaf detached from the data/generator graphs.
        interpolated = (alpha * real_data.detach() +
                        (1 - alpha) * generated_data.detach())
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = self.netD(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated,
                               inputs=interpolated,
                               grad_outputs=torch.ones_like(prob_interpolated),
                               create_graph=True,
                               retain_graph=True,
                               only_inputs=True)[0]

        # Flatten so we can take one norm per example in the batch.
        gradients = gradients.reshape(b_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

        # Return gradient penalty
        return gp_weight * ((gradients_norm - 1)**2).mean()
示例#24
0
    def _gradient_penalty(self, real_data, generated_data):
        """Compute the gradient penalty for the current update.

        From https://github.com/EmilienDupont/wgan-gp/blob/master/training.py -> _gradient_penalty()

        This object is itself called as the discriminator. The interpolation
        coefficient has shape ``(batch, 1, ..., 1)``. For 2-D input this is
        the same draw as before. For higher-rank input it now broadcasts per
        sample; the former ``(batch, 1)`` tensor was aligned with the
        trailing dims instead.

        Returns:
            ``self.LAMBDA * mean((||grad||_2 - 1) ** 2)`` as a scalar tensor.
        """
        discriminator = self
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        # A fresh leaf detached from the data/generator graphs.
        interpolated = (alpha * real_data.detach() +
                        (1 - alpha) * generated_data.detach())
        if self.use_cuda:
            interpolated = interpolated.cuda()
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = discriminator(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,
            retain_graph=True)[0]

        # Flatten so we can take one norm per example in the batch.
        gradients = gradients.reshape(batch_size, -1)

        # Return gradient penalty
        return self.LAMBDA * ((gradients.norm(2, dim=1) - 1)**2).mean()
示例#25
0
def _gradient_penalty_centered_(c_real, model, gp_weight, center=0.):
    """Gradient penalty for the discriminator, centred on ``center``.

    Sets ``requires_grad`` on ``c_real`` in place, evaluates ``model.D`` on it
    and returns ``gp_weight * mean((||grad||_2 - center) ** 2)``. The norm is
    taken over each sample's flattened gradient, with a 1e-12 epsilon inside
    the square root. With the default ``center=0.`` this penalises the
    squared gradient norm.

    Returns:
        Scalar tensor that keeps its autograd graph.
    """
    B = c_real.size(0)
    c_real.requires_grad_(True)

    # Calculate gradients of the critic output with respect to the examples;
    # the seed follows d's device/dtype instead of being forced onto CUDA.
    d = model.D(c_real)
    gradients = torch_grad(outputs=d,
                           inputs=c_real,
                           grad_outputs=torch.ones_like(d),
                           create_graph=True,
                           retain_graph=True)[0]

    # Flatten so we can take one norm per example in the batch.
    gradients = gradients.reshape(B, -1)
    gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

    # Return gradient penalty
    return gp_weight * ((gradients_norm - center)**2).mean()
示例#26
0
    def _gradient_penalty(self, real_data, generated_data):
        """WGAN-GP penalty ``self.gp_weight * mean((||grad D(x_hat)|| - 1) ** 2)``.

        ``x_hat`` interpolates the two batches with one random coefficient
        per sample, shaped ``(batch, 1, ..., 1)``. This supports any input
        rank; previously only 4-D input was accepted. The mean per-sample
        gradient norm is appended to ``self.losses['gradient_norm']`` as a
        Python float.

        Returns:
            Scalar tensor that keeps its autograd graph.
        """
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        # A fresh leaf detached from the data/generator graphs.
        interpolated = (alpha * real_data.detach() +
                        (1 - alpha) * generated_data.detach())
        if self.use_cuda:
            interpolated = interpolated.cuda()
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = self.D(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,
            retain_graph=True)[0]

        # Flatten so we can take one norm per example in the batch.
        gradients = gradients.reshape(batch_size, -1)
        self.losses['gradient_norm'].append(
            gradients.norm(2, dim=1).mean().item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1)**2).mean()
示例#27
0
def GradientPenalty(discriminator_model,
                    real_data,
                    generated_data,
                    gp_weight=10):
    """Compute the WGAN-GP gradient penalty for ``discriminator_model``.

    Each real sample is linearly mixed with its generated counterpart using a
    per-sample random coefficient; the penalty is
    ``gp_weight * mean((||dD/dx||_2 - 1) ** 2)`` over those interpolates.

    Args:
        discriminator_model: callable critic mapping a batch to scores.
        real_data: batch of real samples; dim 0 is the batch dimension and
            any number of trailing dimensions is supported.
        generated_data: batch of generated samples, same shape as
            ``real_data``.
        gp_weight: multiplier applied to the penalty.

    Returns:
        Scalar tensor holding the weighted penalty.
    """
    batch_size = real_data.size(0)
    # Work on the device the data already lives on. Checking
    # torch.cuda.is_available() instead crashed for CPU data on a CUDA
    # machine and for data placed on a non-default GPU.
    device = real_data.device

    # One coefficient per sample, broadcast over all trailing dimensions.
    # Drawn from the CPU generator (as before) and then moved to `device`.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape).to(device).expand_as(real_data)

    interpolated = alpha * real_data + (1 - alpha) * generated_data
    # Detach and re-enable grad: a fresh leaf, matching what the deprecated
    # Variable(t, requires_grad=True) produced, so the penalty is not
    # backpropagated into whatever produced real/generated data.
    interpolated = interpolated.detach().requires_grad_(True)

    # Critic scores for the interpolated examples.
    prob_interpolated = discriminator_model(interpolated)

    # d(score)/d(input); create_graph=True keeps the penalty differentiable
    # w.r.t. the critic parameters.
    gradients = torch_grad(
        outputs=prob_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(prob_interpolated),
        create_graph=True,
        retain_graph=True)[0]

    # Flatten everything but the batch dim to take one norm per example;
    # reshape (unlike view) also copes with non-contiguous gradients.
    gradients = gradients.reshape(batch_size, -1)

    # The derivative of sqrt blows up near 0, so compute the norm by hand
    # with a small epsilon.
    gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

    return gp_weight * ((gradients_norm - 1)**2).mean()
示例#28
0
    def _gradient_penalty(self, real_data, generated_data):
        """Compute the WGAN-GP gradient penalty for the critic ``self.c``.

        Args:
            real_data: batch of real samples; dim 0 is the batch dimension
                (e.g. ``(batch, features)`` or ``(batch, channels, length)``).
            generated_data: batch of generated samples, same shape as
                ``real_data``.

        Returns:
            Scalar tensor ``self.gp_weight * mean((||dC/dx||_2 - 1) ** 2)``.
            As a side effect the batch-mean gradient norm is appended to
            ``self.losses['gradient_norm']``.
        """
        batch_size = real_data.size(0)

        # One interpolation coefficient per sample, broadcast over all
        # trailing dims. The old torch.rand(batch_size, 1) only lined up with
        # 2-D input. For (batch, channels, length) input, expand_as
        # right-aligns shapes, so it either raised or, when
        # batch == channels, mixed per channel instead of per sample.
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape).to(real_data.device)
        alpha = alpha.expand_as(real_data)

        # Interpolate detached data so the penalty does not reach the
        # generator, then make the result a gradient-tracking leaf.
        interpolated = (alpha * real_data.detach()
                        + (1 - alpha) * generated_data.detach())
        interpolated.requires_grad_(True)

        # Pass interpolated data through the critic.
        prob_interpolated = self.c(interpolated)

        # d(score)/d(input); create_graph=True keeps the penalty
        # differentiable w.r.t. the critic parameters.
        gradients = torch_grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,
            retain_graph=True)[0]

        # Flatten everything but the batch dim to take one norm per example.
        gradients = gradients.reshape(batch_size, -1)
        self.losses['gradient_norm'].append(
            gradients.norm(2, dim=1).mean().item())

        # The derivative of sqrt blows up near 0, so compute the norm by hand
        # with a small epsilon.
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

        return self.gp_weight * ((gradients_norm - 1)**2).mean()
示例#29
0
    def __call__(self, img: torch.Tensor, cls_idx: int,
                 out_size=(299, 299)) -> "tuple[np.ndarray, np.ndarray]":
        """Compute a Grad-CAM heatmap for class ``cls_idx``.

        Args:
            img: input image batch passed to ``self.model``.
            cls_idx: index of the class whose evidence is visualised.
            out_size: ``(width, height)`` handed to ``cv2.resize`` for the
                heatmap; the default keeps the previous fixed 299x299 output.

        Returns:
            heatmap: array of shape ``(1, height, width)`` for the first image
                in the batch, rescaled to [0, 1] (all zeros when the map is
                constant, instead of NaN from a 0/0 division).
            probs: softmax over the model's unmodified outputs along dim 1.
        """
        # Forward pass; feature maps are read from the module returned by
        # f_get_last_module (presumably captured by a forward hook — confirm).
        output = self.model(img)
        feature_maps = self.f_get_last_module(self.model).output

        # Select the target class through a one-hot grad_outputs rather than
        # zeroing the other logits in place. The gradient is the same, but
        # `output` stays intact for `probs` below (they were previously
        # computed from the zeroed logits), and no in-place edit can
        # invalidate tensors autograd saved for backward.
        selector = torch.zeros_like(output)
        selector[:, cls_idx] = 1
        gradients = torch_grad(outputs=output,
                               inputs=feature_maps,
                               grad_outputs=selector,
                               create_graph=False,
                               retain_graph=True)[0]

        # Per-channel weights: global average of the gradients over the
        # spatial dims (avg_pool2d with kernel shape[2] was only global for
        # square feature maps).
        alpha = gradients.mean(dim=(2, 3), keepdim=True)

        # Localization map: ReLU of the weighted channel sum.
        heatmap = F.relu(torch.sum(feature_maps * alpha, dim=1, keepdim=True))
        heatmap = cv2.resize(heatmap[0, 0, :, :].detach().cpu().numpy(),
                             tuple(out_size))

        # Rescale to [0, 1]; guard against a constant map.
        lo, hi = np.min(heatmap), np.max(heatmap)
        if hi > lo:
            heatmap = (heatmap - lo) / (hi - lo)
        else:
            heatmap = np.zeros_like(heatmap)
        heatmap = np.expand_dims(heatmap, axis=0)

        # Class probabilities from the untouched model outputs.
        probs = F.softmax(output.detach(), dim=1).cpu().numpy()

        return heatmap, probs
示例#30
0
    def _grad_penalty(self, real_data, gen_data):
        """One-sided gradient penalty on the batch-mean critic gradient norm.

        Real and generated samples are mixed with a per-sample random
        coefficient and scored by ``self.D``. The per-example gradient norms
        are computed in double precision and averaged.

        Args:
            real_data: batch of real samples; dim 0 is the batch dimension
                and any number of trailing dimensions is supported.
            gen_data: batch of generated samples, same shape as ``real_data``.

        Returns:
            ``(penalty, mean_norm)``: ``penalty`` is a double tensor of shape
            ``(1,)`` equal to ``self.gp_weight * max(0, mean_norm - 1) ** 2``;
            ``mean_norm`` is the batch-mean gradient norm as a Python float.
        """
        batch_size = real_data.size(0)
        device = real_data.device

        # One mixing coefficient per sample, broadcast over trailing dims.
        # Drawn from the CPU generator as before, then moved to the data.
        t_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        t = torch.rand(t_shape).to(device).expand_as(real_data)

        # Mixed sample from real and fake. It is made an explicit leaf that
        # requires grad (previously this came indirectly from
        # requires_grad on `t`).
        interpol = t * real_data.detach() + (1 - t) * gen_data.detach()
        interpol.requires_grad_(True)

        # No torch.autograd.set_detect_anomaly(True) here. It switched on the
        # global anomaly-detection debug mode on every call, which slows all
        # subsequent autograd work.
        prob_interpol = self.D(interpol)
        gradients = torch_grad(
            outputs=prob_interpol,
            inputs=interpol,
            grad_outputs=torch.ones_like(prob_interpol),
            create_graph=True,
            retain_graph=True)[0]
        gradients = gradients.reshape(batch_size, -1)

        # Per-example norm accumulated in double; eps keeps sqrt's
        # derivative finite at zero.
        eps = 1e-10
        gradients_norm = torch.sqrt(
            torch.sum(gradients**2, dim=1, dtype=torch.double) + eps)
        mean_norm = gradients_norm.mean()

        # Only penalise a mean norm above 1; max against a (1,)-shaped zero
        # keeps the returned penalty's shape (1,) and dtype double.
        zero = torch.zeros(1, dtype=torch.double, device=mean_norm.device)
        penalty = self.gp_weight * (torch.max(zero, mean_norm - 1)**2)
        return penalty, mean_norm.item()