def _inversion_coef(self, constants):
     """Compute the model coefficients and store them in ``self.coef_``.

     A penalty matrix ``exp(lambda) * I`` is built (sized after ``XTX``)
     and handed, together with ``XTX`` and ``XTy``, to the coefficient
     helpers:

     * ``self.elastic_feature_sparsity`` truthy: blend of the dense and
       sparse solutions, weighted by ``mu`` and ``1 - mu`` where
       ``mu = sigmoid(self._params["feature_elastic_coef"])``;
     * otherwise, ``self.feature_sparsity`` truthy: sparse solution only;
     * otherwise: dense solution only.

     Args:
         constants: mapping whose ``values()`` yield exactly six items;
             the fourth and fifth are used as ``XTX`` and ``XTy``.
             NOTE(review): the mapping is consumed positionally, so its
             insertion order matters -- confirm against the caller.

     Returns:
         None. The result of ``self._tensor_to_array`` is written to
         ``self.coef_``.
     """
     # Only the fourth and fifth entries are needed here; the remaining
     # ones are unpacked solely to keep the six-item contract.
     _, _, _, XTX, XTy, _ = constants.values()
     use_sparsity = self.feature_sparsity

     eye = torch.eye(XTX.shape[0], dtype=torch.float32)
     if self.GPU:
         eye = eye.cuda()
     ridge = torch.exp(self._params["lambda"]) * eye

     if self.elastic_feature_sparsity:
         mix = torch.sigmoid(self._params["feature_elastic_coef"])
         dense = self._inversion_coef_without_sparsity(ridge, XTX, XTy)
         sparse = self._inversion_coef_with_sparsity(ridge, XTX, XTy)
         solution = dense * mix + sparse * (1 - mix)
     else:
         if use_sparsity:
             solver = self._inversion_coef_with_sparsity
         else:
             solver = self._inversion_coef_without_sparsity
         solution = solver(ridge, XTX, XTy)

     self.coef_ = self._tensor_to_array(solution)
 def _inversion_forward(self, constants, feature_sparsity):
     """Return the fitted values ``X @ inv(XTX + exp(lambda) * I) @ XTy``.

     When ``feature_sparsity`` is truthy, ``X``, ``XTX`` and ``XTy`` are
     first masked with ``diag(self._sparsify("feature"))``.

     Args:
         constants: mapping whose ``values()`` yield exactly four items,
             used positionally as ``X``, ``y``, ``XTX`` and ``XTy``
             (``y`` is not used here).
             NOTE(review): insertion order matters -- confirm against
             the caller.
         feature_sparsity: whether to apply the feature mask.

     Returns:
         The tensor of fitted values. Both branches return a single
         tensor; the sparse branch previously tried to return an
         undefined ``permuted_y_hat`` alongside it and always raised
         ``NameError``.
     """
     X, _, XTX, XTy = constants.values()

     if feature_sparsity:
         sparse_vector = torch.diag(self._sparsify("feature"))
         X = X @ sparse_vector
         XTX = sparse_vector @ XTX @ sparse_vector
         XTy = sparse_vector @ XTy

     # Allocate the identity on the same device as XTX instead of forcing
     # CUDA, so the method also works when the data lives on the CPU.
     identity = torch.eye(XTX.shape[0], dtype=torch.float32,
                          device=XTX.device)
     penality = torch.exp(self._params["lambda"]) * identity

     inv = torch.inverse(XTX + penality)
     projection_matrix = X @ inv
     return projection_matrix @ XTy
 def _inversion_coef_with_sparsity(self, penality, XTX, XTy):
     """Return the coefficients computed with the feature mask applied.

     With ``S = diag(self._sparsify("feature"))`` the result is
     ``S @ inv(S @ XTX @ S + penality) @ (S @ XTy)``.

     Args:
         penality: matrix added to the masked ``XTX`` before inversion.
         XTX: matrix masked on both sides by ``S``.
         XTy: right-hand side, masked on the left by ``S``.

     Returns:
         The tensor produced by the expression above.
     """
     mask = torch.diag(self._sparsify("feature"))
     masked_gram = torch.matmul(torch.matmul(mask, XTX), mask)
     masked_rhs = torch.matmul(mask, XTy)
     regularised_inverse = torch.inverse(masked_gram + penality)
     # Mask the inverse on the left as well, then apply it to the masked
     # right-hand side.
     return torch.matmul(torch.matmul(mask, regularised_inverse), masked_rhs)