def loss_function(x,
                  x_hat,
                  q_z,
                  free_bits,
                  beta_rate,
                  global_step,
                  max_beta,
                  device,
                  weight_vec,
                  balance,
                  loss_weight,
                  lambda1=10,
                  lambda2=0.5):
    '''
    Combine a reconstruction term, a KL term and a weight penalty into one loss.

    balance describes [weight for MAE, weight for KL, weight for l1]

    Args:
        x: Target, forwarded to ``weighted_MAE_loss``.
        x_hat: Reconstruction, forwarded to ``weighted_MAE_loss``.
        q_z: Forwarded to ``kl_loss_beta_vae``.
        free_bits, beta_rate, global_step, max_beta: Forwarded unchanged to
            ``kl_loss_beta_vae``.
        device: Device the MAE term is moved to; also forwarded to
            ``kl_loss_beta_vae``.
        weight_vec: Tensor with at least two dimensions; the penalty reduces
            it over ``dim=1`` before summing.
        balance: Indexable of three weights ``[MAE, KL, l1]``.
        loss_weight: Multiplier applied to the combined loss.
        lambda1, lambda2: Forwarded to ``weighted_MAE_loss``.

    Returns:
        Tuple ``(weighted_loss, raw_loss, KL, bits, MAE, l1)`` where
        ``raw_loss = balance[0] * MAE + balance[1] * KL + balance[2] * l1``
        and ``weighted_loss = raw_loss * loss_weight``.
    '''
    mae_value = weighted_MAE_loss(x,
                                  x_hat,
                                  include_candidate=True,
                                  lambda1=lambda1,
                                  lambda2=lambda2)
    # torch.as_tensor (unlike torch.tensor) returns a tensor input as-is instead
    # of making a detached copy, so the MAE term keeps its autograd history and
    # contributes gradients. Python/NumPy scalars are still converted.
    MAE = torch.as_tensor(mae_value).to(device)
    KL, bits = kl_loss_beta_vae(q_z, free_bits, beta_rate, global_step,
                                max_beta, device)

    # Despite the name, this is not an L1 norm: it is the sum over rows of
    # 1 / (sum of squared entries of the row). A row of all zeros yields inf.
    row_sq_norms = torch.sum(torch.pow(weight_vec, 2), dim=1)
    l1 = torch.sum(torch.pow(row_sq_norms, -1))

    raw_loss = balance[0] * MAE + balance[1] * KL + balance[2] * l1
    weighted_loss = raw_loss * loss_weight
    return weighted_loss, raw_loss, KL, bits, MAE, l1
Ejemplo n.º 2
0
def loss_function(x,
                  x_hat,
                  q_z,
                  free_bits,
                  beta_rate,
                  global_step,
                  max_beta,
                  device,
                  weight_vec,
                  balance,
                  loss_weight,
                  lambda1=10,
                  lambda2=0.5):
    '''
    Combine a reconstruction term, a KL term and a weight penalty into one loss.

    NOTE: This function is outdated, it was utilized by 'VAE', and not to be used now without debugging.
    Currently utilizing 'simple_loss', defined below.

    =========== ARGUMENTS ===========
        > x           : target, forwarded to ``weighted_MAE_loss``
        > x_hat       : reconstruction, forwarded to ``weighted_MAE_loss``
        > q_z         : forwarded to ``kl_loss_beta_vae``
        > free_bits   : forwarded to ``kl_loss_beta_vae``
        > beta_rate   : forwarded to ``kl_loss_beta_vae``
        > global_step : forwarded to ``kl_loss_beta_vae``
        > max_beta    : forwarded to ``kl_loss_beta_vae``
        > device      : device the MAE term is moved to; also forwarded to
                        ``kl_loss_beta_vae``
        > weight_vec  : tensor with at least two dimensions; reduced over
                        ``dim=1`` for the penalty term
        > balance     : weights [MAE, KL, l1] for the three terms
        > loss_weight : multiplier applied to the combined loss
        > lambda1     : forwarded to ``weighted_MAE_loss``
        > lambda2     : forwarded to ``weighted_MAE_loss``

    ============ RETURNS ============
        > weighted_loss : raw_loss * loss_weight
        > raw_loss      : balance[0] * MAE + balance[1] * KL + balance[2] * l1
        > KL            : first value returned by ``kl_loss_beta_vae``
        > bits          : second value returned by ``kl_loss_beta_vae``
        > MAE           : reconstruction term, on ``device``
        > l1            : sum over rows of 1 / (squared L2 norm of the row)
    '''
    mae_value = weighted_MAE_loss(x,
                                  x_hat,
                                  include_candidate=True,
                                  lambda1=lambda1,
                                  lambda2=lambda2)
    # torch.as_tensor (unlike torch.tensor) returns a tensor input as-is instead
    # of making a detached copy, so the MAE term keeps its autograd history and
    # contributes gradients. Python/NumPy scalars are still converted.
    MAE = torch.as_tensor(mae_value).to(device)
    KL, bits = kl_loss_beta_vae(q_z, free_bits, beta_rate, global_step,
                                max_beta, device)

    # Despite the name, this is not an L1 norm: it is the sum over rows of
    # 1 / (sum of squared entries of the row). A row of all zeros yields inf.
    row_sq_norms = torch.sum(torch.pow(weight_vec, 2), dim=1)
    l1 = torch.sum(torch.pow(row_sq_norms, -1))

    raw_loss = balance[0] * MAE + balance[1] * KL + balance[2] * l1
    weighted_loss = raw_loss * loss_weight
    return weighted_loss, raw_loss, KL, bits, MAE, l1