Exemplo n.º 1
0
def loss_encoder_hessian(E,
                         samples,
                         alpha,
                         scale_alpha=False,
                         hessian_layers=None,
                         current_layer=None,
                         hessian_weight=0.01):
    """Weighted hessian-penalty regularization loss for encoder ``E``.

    Parameters
    ----------
    E : callable
        Encoder network, forwarded to ``hessian_penalty``.
    samples : Tensor
        Inputs passed as ``z`` to ``hessian_penalty``.
    alpha : float
        Blend factor; also scales the penalty when applicable.
    scale_alpha : bool
        Force scaling of the penalty by ``alpha``.
    hessian_layers : list[int] | None
        Layers whose penalty norm is returned (defaults to ``[-1, -2]``).
    current_layer : list[int] | None
        Layer(s) currently being trained (defaults to ``[-1]``).
    hessian_weight : float
        Final weight applied to the penalty.

    Returns
    -------
    Tensor
        The weighted hessian-penalty loss.
    """
    # Avoid mutable default arguments (shared across calls).
    if hessian_layers is None:
        hessian_layers = [-1, -2]
    if current_layer is None:
        current_layer = [-1]
    loss = hessian_penalty(E,
                           z=samples,
                           alpha=alpha,
                           return_norm=hessian_layers)
    # Bug fix: the original wrote `current_layer in hessian_layers`, which
    # asks whether the *list* [-1] is an element of a list of ints — always
    # False. The sibling `loss_generator` uses `layer in current_layer`, so
    # the intent is an overlap test between the two layer lists.
    if scale_alpha or any(layer in hessian_layers for layer in current_layer):
        loss = loss * alpha
    return loss * hessian_weight
Exemplo n.º 2
0
    def forward(self,
                F,
                G,
                E,
                scale,
                alpha,
                z,
                labels=None,
                hessian_layers=None,
                current_layer=None):
        """Autoencoding loss in latent space, with optional TV and hessian terms.

        Maps ``z`` through ``F``, decodes with ``G`` and re-encodes with
        ``E``; the loss compares the latents ``F_z`` with the re-encoded
        ``E_z``. Optionally adds a total-variation term on ``G_z`` and a
        hessian penalty on ``G``.

        NOTE(review): assumes F_z is (batch, n_latents, dim) and E_z is
        (batch, dim) — inferred from the reshape/repeat below; confirm
        against F/E implementations.
        """
        # Avoid shared mutable default arguments.
        if hessian_layers is None:
            hessian_layers = [3]
        if current_layer is None:
            current_layer = [0]

        F_z = F(z, scale, z2=None, p_mix=0)

        # Autoencoding loss in latent space
        G_z = G(F_z, scale, alpha)
        E_z = E(G_z, alpha)

        if labels is not None:
            # Broadcast E_z across the latent axis so it lines up with F_z.
            E_z = E_z.reshape(E_z.shape[0], 1,
                              E_z.shape[1]).repeat(1, F_z.shape[1], 1)
            if self.use_dist:
                x = self.p_dist(F_z, E_z)
                # Pairwise same-label indicator matrix as the target.
                y = torch.eq(labels, labels.T).float().to(x.device)
                loss = self.loss_fn(x, y)
            else:
                # Contrast each latent against a random batch permutation.
                perm = torch.randperm(E_z.shape[0], device=E_z.device)
                E_z_hat = torch.index_select(E_z, 0, perm)
                F_z_hat = torch.index_select(F_z, 0, perm)
                # Bug fix: original concatenated the undefined name
                # `F_x_hat` (NameError at runtime); the permuted latents
                # computed above are `F_z_hat`.
                F_hat = torch.cat([F_z, F_z_hat], 0)
                E_hat = torch.cat([E_z, E_z_hat], 0)
                loss = self.loss_fn(F_hat, E_hat, labels)
        else:
            # Compare only the first latent slot with the re-encoded vector.
            F_x = F_z[:, 0, :]
            loss = self.loss_fn(F_x, E_z)

        if self.use_tv:
            loss += self.total_variation(G_z)

        # Hessian applied to G here
        if self.enable_hessian:
            h_loss = hessian_penalty(G,
                                     z=F_z,
                                     scale=scale,
                                     alpha=alpha,
                                     return_norm=hessian_layers)
            h_loss *= self.hessian_weight
            # Bug fix: `current_layer in hessian_layers` compared a list
            # against ints and was always False; test list overlap instead.
            if any(layer in hessian_layers for layer in current_layer):
                h_loss = h_loss * alpha
            loss += h_loss
        return loss
Exemplo n.º 3
0
def loss_generator_hessian(G,
                           F,
                           z,
                           scale,
                           alpha,
                           scale_alpha=False,
                           hessian_layers=None,
                           current_layer=None,
                           hessian_weight=0.01):
    """Weighted hessian-penalty loss for generator ``G`` on mapped latents.

    Parameters
    ----------
    G : callable
        Generator network, forwarded to ``hessian_penalty``.
    F : callable
        Mapping network; ``F(z, scale, z2=None, p_mix=0)`` produces the
        latents fed to the penalty.
    z : Tensor
        Input noise for the mapping network.
    scale, alpha
        Progressive-growing scale and blend factor, forwarded to the penalty.
    scale_alpha : bool
        Force scaling of the penalty by ``alpha``.
    hessian_layers : list[int] | None
        Layers whose penalty norm is returned (defaults to ``[3]``).
    current_layer : list[int] | None
        Layer(s) currently being trained (defaults to ``[0]``).
    hessian_weight : float
        Final weight applied to the penalty.

    Returns
    -------
    Tensor
        The weighted hessian-penalty loss.
    """
    # Avoid mutable default arguments (shared across calls).
    if hessian_layers is None:
        hessian_layers = [3]
    if current_layer is None:
        current_layer = [0]
    loss = hessian_penalty(G,
                           z=F(z, scale, z2=None, p_mix=0),
                           scale=scale,
                           alpha=alpha,
                           return_norm=hessian_layers)
    # Bug fix: `current_layer in hessian_layers` compared a list against a
    # list of ints and was always False; test for overlap between the layer
    # lists instead (matching `loss_generator`'s per-layer check).
    if scale_alpha or any(layer in hessian_layers for layer in current_layer):
        loss = loss * alpha
    return loss * hessian_weight
Exemplo n.º 4
0
def loss_generator(E,
                   D,
                   alpha,
                   fake_samples,
                   enable_hessian=True,
                   hessian_layers=None,
                   current_layer=None,
                   hessian_weight=0.01):
    """Generator-side adversarial loss, with optional hessian penalty on ``E``.

    Parameters
    ----------
    E : callable
        Encoder; called as ``E(fake_samples, alpha)``.
    D : callable
        Discriminator head applied to the encoded latents.
    alpha : float
        Blend factor; also scales per-layer penalties for current layers.
    fake_samples : Tensor
        Generated samples to encode.
    enable_hessian : bool
        Whether to add the per-layer hessian penalties.
    hessian_layers : list[int] | None
        Layers to penalize (defaults to ``[-1, -2]``).
    current_layer : list[int] | None
        Layer(s) currently being trained (defaults to ``[-1]``).
    hessian_weight : float
        Weight applied to each per-layer penalty.

    Returns
    -------
    Tensor
        Scalar loss.
    """
    # Avoid mutable default arguments (shared across calls).
    if hessian_layers is None:
        hessian_layers = [-1, -2]
    if current_layer is None:
        current_layer = [-1]
    # Hessian applied to E here
    # Minimize negative = Maximize positive (Minimize correct D predictions for fake data)
    E_z = E(fake_samples, alpha)
    loss = softplus(-D(E_z)).mean()
    if enable_hessian:
        for layer in hessian_layers:
            h_loss = hessian_penalty(
                E, z=fake_samples, alpha=alpha,
                return_norm=layer) * hessian_weight
            # Extra alpha scaling for layers still being blended in.
            if layer in current_layer:
                h_loss = h_loss * alpha
            loss += h_loss
    return loss
Exemplo n.º 5
0
 def forward(self,
             E: Module,
             D: Module,
             alpha: float,
             fake_samples: Tensor,
             enable_hessian: bool = True,
             hessian_layers=None,
             current_layer=None,
             hessian_weight: float = 0.01):
     """Generator-side adversarial loss with optional hessian penalty on ``E``.

     Mirrors the free function ``loss_generator``: encode the fake samples,
     push the discriminator toward classifying them as real via softplus,
     and optionally add per-layer hessian penalties on the encoder.

     Returns a scalar loss tensor.
     """
     # Avoid mutable default arguments (shared across calls).
     if hessian_layers is None:
         hessian_layers = [-1, -2]
     if current_layer is None:
         current_layer = [-1]
     # Hessian applied to E here
     # Minimize negative = Maximize positive (Minimize correct D predictions for fake data)
     E_z = E(fake_samples, alpha)
     loss = F.softplus(-D(E_z)).mean()
     if enable_hessian:
         for layer in hessian_layers:
             # CALAE.loss.hessian_penalty
             h_loss = hessian_penalty(
                 E, z=fake_samples, alpha=alpha,
                 return_norm=layer) * hessian_weight
             # Extra alpha scaling for layers still being blended in.
             if layer in current_layer:
                 h_loss = h_loss * alpha
             loss += h_loss
     return loss