def fit(
    self,
    X_ptrs: List[torch.Tensor],
    C_ptrs: List[torch.Tensor],
    y_ptrs: List[torch.Tensor],
):
    """
    Fit the regression on remotely held data, combining worker aggregates under MPC.

    Per-worker dot products and QR factors of the covariate matrices are
    computed remotely. They are then secret-shared, summed and combined
    (the Combine stage of DASH's algorithm) to obtain the coefficients and
    their variance estimates. Nothing is returned; results are stored on
    the instance.

    Args:
        X_ptrs: pointers to the X tensors, one per worker.
        C_ptrs: pointers to the covariate matrices, one per worker.
        y_ptrs: pointers to the response tensors, one per worker. 1-dim
            tensors are unsqueezed into 2-dim column form before use. The
            caller's list itself is not modified.

    Sets:
        self.workers, self.coef, self.sigma2, self.se, plus whatever
        ``self._compute_pvalues()`` sets.
    """
    # Checking if the pointers are as expected
    self._check_ptrs(X_ptrs, C_ptrs, y_ptrs)

    # Make every y a 2-dim tensor by unsqueezing 1-dim ones at dim 1.
    # A new local list is built instead of assigning into y_ptrs, so the
    # caller's list is not mutated as a side effect of fitting.
    y_ptrs = [y.unsqueeze(1) if len(y.shape) < 2 else y for y in y_ptrs]

    self.workers = self._get_workers(X_ptrs)

    # Computing aggregated pairwise dot products remotely
    XX_ptrs, Xy_ptrs, yy_ptrs, CX_ptrs, Cy_ptrs = self._remote_dot_products(
        X_ptrs, C_ptrs, y_ptrs
    )

    # Compute remote QR decompositions (only the R factors are returned)
    R_ptrs = self._remote_qr(C_ptrs)

    # Secret share tensors between hbc_worker, crypto_provider and a random worker
    # and compute aggregates. It corresponds to the Combine stage of DASH's algorithm
    idx = random.randint(0, len(self.workers) - 1)
    XX_shared = sum(self._share_ptrs(XX_ptrs, idx))
    Xy_shared = sum(self._share_ptrs(Xy_ptrs, idx))
    yy_shared = sum(self._share_ptrs(yy_ptrs, idx))
    CX_shared = sum(self._share_ptrs(CX_ptrs, idx))
    Cy_shared = sum(self._share_ptrs(Cy_ptrs, idx))
    # Stack the per-worker R factors row-wise before re-factorizing them.
    R_cat_shared = torch.cat(self._share_ptrs(R_ptrs, idx), dim=0)

    # QR decomposition of R_cat_shared; only the resulting R is kept.
    _, R_shared = qr(R_cat_shared, norm_factor=self.total_size ** (1 / 2))

    # Compute inverse of the upper-triangular matrix R
    R_shared_inv = self._inv_upper(R_shared)

    Qy = R_shared_inv.t() @ Cy_shared
    QX = R_shared_inv.t() @ CX_shared

    denominator = XX_shared - (QX ** 2).sum(dim=0)
    # Need the line below to perform inverse of a number in MPC:
    # `0 * denominator + 1` is a tensor of ones derived from `denominator`,
    # so the reciprocal is computed as a tensor/tensor division. Presumably a
    # plain scalar numerator is not supported by the shared division.
    inv_denominator = ((0 * denominator + 1) / denominator).squeeze()

    coef_shared = (Xy_shared - QX.t() @ Qy).squeeze() * inv_denominator
    sigma2_shared = (
        (yy_shared - Qy.t() @ Qy).squeeze() * inv_denominator - coef_shared ** 2
    ) / self._dgf

    # Retrieve the shared results with .get() and convert them back via
    # float_precision() before storing them on the instance.
    self.coef = coef_shared.get().float_precision()
    self.sigma2 = sigma2_shared.get().float_precision()
    self.se = self.sigma2 ** (1 / 2)

    self._compute_pvalues()
@staticmethod
def _remote_qr(C_ptrs):
    """
    Perform the QR decompositions of the permanent covariate matrices remotely.

    Declared as a static method because it takes no ``self`` but is invoked
    as ``self._remote_qr(C_ptrs)``; without the decorator the instance would
    be bound to ``C_ptrs``.

    Args:
        C_ptrs: pointers to the covariate matrices, one per worker.

    Returns:
        A list with the upper-triangular R factor of each matrix, in the same
        order as ``C_ptrs``; each R stays located on its worker.
    """
    # qr returns a (Q, R) pair; Q is discarded and only R is collected.
    return [r for _, r in map(qr, C_ptrs)]