def loss_fn(data):
    """Statistic-matching loss between the reconstructed batch and A.

    Reconstructs the batch with the learnable affine map, then penalizes
    the distance between its per-class means/variances and the reference
    statistics of A. Returns a dict with the total loss tensor plus
    scalar diagnostics for logging.
    """
    inputs, labels = data
    inputs = reconstruct(inputs)
    # Per-class first/second moments of the reconstructed batch.
    batch_means, batch_vars = utility.c_stats(inputs, labels, n_classes)
    mean_term = (batch_means - A_means).norm(dim=1).mean()
    var_term = (batch_vars - A_vars).norm(dim=1).mean()
    return {
        'loss': mean_term + var_term,
        '[losses] mean': mean_term.item(),
        '[losses] var': var_term.item(),
        'c-entropy': dataset.cross_entropy(inputs),
    }
print("Cross Entropy of A:", dataset.cross_entropy(X_A).item())
print("Cross Entropy of B:", dataset.cross_entropy(X_B).item())


# ======= reconstruct Model =======
# Learnable affine map; identity/zero init makes reconstruct a no-op at start.
A = torch.eye(2, requires_grad=True)
b = torch.zeros(2, requires_grad=True)


def reconstruct(X):
    """Apply the learnable affine transform X @ A + b."""
    return X @ A + b


# ======= Collect Stats from A =======
# collect stats
# shape: [n_class, n_dims] = [2, 2]
A_means, A_vars = utility.c_stats(X_A, Y_A, n_classes)

# ======= Loss Function =======
# NOTE(review): commented-out alternative losses kept for reference.
# def loss_frechet(X, Y=Y_B):
#     X_means, X_vars, _ = utility.c_mean_var(X, Y, n_classes)
#     diff_mean = ((X_means - A_means)**2).sum(dim=0).mean()
#     diff_var = (X_vars + A_vars - 2 * (X_vars * A_vars).sqrt()
#                 ).sum(dim=0).mean()
#     loss = (diff_mean + diff_var)
#     return loss

# def loss_fn(X, Y=Y_B):
#     # log likelihood * 2 - const:
#     # diff_mean = (((X_means - A_means)**2 / X_vars.detach())).sum(dim=0)
return torch.cat(layer_activations, dim=1) # ======= reconstruct Model ======= A = torch.eye((2), requires_grad=True) b = torch.zeros((2), requires_grad=True) def reconstruct(X): return X @ A + b # ======= Loss Function ======= with torch.no_grad(): X_A_proj = project(X_A) A_proj_means, A_proj_vars = utility.c_stats(X_A_proj, Y_A, n_classes) # def loss_frechet(X, Y=Y_B): # X_proj = project(X) # X_proj_means, X_proj_vars, _ = utility.c_mean_var(X_proj, Y, n_classes) # diff_mean = ((X_proj_means - A_proj_means)**2).sum(dim=0).mean() # diff_var = (X_proj_vars + A_proj_vars # - 2 * (X_proj_vars * A_proj_vars).sqrt() # ).sum(dim=0).mean() # loss = (diff_mean + diff_var) # return loss def loss_fn(data): X, Y = data X = reconstruct(X)