def _fit(self, X, _transform=False):
    """
    Fit the model with X.

    Parameters
    ----------
    X : dask cuDF input
        Distributed input data. ``X.shape[1]`` is used as the number of
        columns passed to every worker's fit task.
    _transform : bool, default False
        Forwarded to each worker's fit task. When True, the per-worker
        results are flattened and returned through ``to_output`` instead
        of ``self``.

    Returns
    -------
    self, or the output built by ``to_output`` when ``_transform`` is True.
    """
    n_cols = X.shape[1]

    data = DistributedDataHandler.create(data=X, client=self.client)
    self.datatype = data.datatype

    # The "tsqr" solver is the only one that requests point-to-point comms.
    use_p2p = self.kwargs.get("svd_solver") == "tsqr"
    comms = CommsContext(comms_p2p=use_p2p)
    comms.init(workers=data.workers)

    # Destroy the comms session even when a step below (e.g. a failed fit
    # task surfaced by raise_exception_from_futures) raises; otherwise the
    # session would be leaked on the error path.
    try:
        data.calculate_parts_to_sizes(comms)

        worker_info = comms.worker_info(comms.worker_addresses)
        parts_to_sizes, _ = parts_to_ranks(self.client,
                                           worker_info,
                                           data.gpu_futures)

        total_rows = data.total_rows

        # Snapshot (worker, parts) pairs once so both submission passes
        # below walk the same workers in the same order.
        worker_parts = list(data.worker_to_parts.items())
        ranks = {worker: data.worker_info[worker]["rank"]
                 for worker, _parts in worker_parts}

        # One model future per worker, keyed by that worker's rank.
        models = {
            ranks[worker]: self.client.submit(
                self._create_model,
                comms.sessionId,
                self._model_func,
                self.datatype,
                **self.kwargs,
                pure=False,
                workers=[worker])
            for worker, _parts in worker_parts
        }

        # One fit future per worker, pinned to the worker that owns the
        # partitions and the model.
        pca_fit = {
            worker: self.client.submit(
                DecompositionSyncFitMixin._func_fit,
                models[ranks[worker]],
                parts,
                total_rows,
                n_cols,
                parts_to_sizes,
                ranks[worker],
                _transform,
                pure=False,
                workers=[worker])
            for worker, parts in worker_parts
        }

        fit_futures = list(pca_fit.values())
        wait(fit_futures)
        raise_exception_from_futures(fit_futures)
    finally:
        comms.destroy()

    # Any worker's model can serve as the internal model; use the first.
    self._set_internal_model(list(models.values())[0])

    if _transform:
        out_futures = flatten_grouped_results(self.client,
                                              data.gpu_futures,
                                              pca_fit)
        return to_output(out_futures, self.datatype)

    return self
def _fit(self, X, _transform=False):
    """
    Fit the model with X.

    Parameters
    ----------
    X : dask cuDF input
        Distributed input data. ``X.shape[1]`` is used as the number of
        columns passed to every worker's fit task.
    _transform : bool, default False
        Forwarded to each worker's fit task. When True, the per-worker
        results are flattened and returned through ``to_output`` instead
        of ``self``.

    Returns
    -------
    self, or the output built by ``to_output`` when ``_transform`` is True.

    Side effects
    ------------
    Sets ``datatype``, ``local_model``, ``components_``,
    ``explained_variance_``, ``explained_variance_ratio_`` and
    ``singular_values_`` on ``self``.
    """
    n_cols = X.shape[1]

    data = DistributedDataHandler.create(data=X, client=self.client)
    self.datatype = data.datatype

    comms = CommsContext(comms_p2p=False)
    comms.init(workers=data.workers)

    # Destroy the comms session even when a step below (e.g. a failed fit
    # task surfaced by raise_exception_from_futures) raises; otherwise the
    # session would be leaked on the error path.
    try:
        data.calculate_parts_to_sizes(comms)

        total_rows = data.total_rows

        # Snapshot (worker, parts) pairs once so both submission passes
        # below walk the same workers in the same order.
        worker_parts = list(data.worker_to_parts.items())
        ranks = {worker: data.worker_info[worker]["rank"]
                 for worker, _parts in worker_parts}

        # One model future per worker, keyed by that worker's rank.
        models = {
            ranks[worker]: self.client.submit(
                self._create_model,
                comms.sessionId,
                self._model_func,
                self.datatype,
                **self.kwargs,
                pure=False,
                workers=[worker])
            for worker, _parts in worker_parts
        }

        # One fit future per worker; each receives only its own rank's
        # entry of data.parts_to_sizes.
        pca_fit = {
            worker: self.client.submit(
                DecompositionSyncFitMixin._func_fit,
                models[ranks[worker]],
                parts,
                total_rows,
                n_cols,
                data.parts_to_sizes[ranks[worker]],
                ranks[worker],
                _transform,
                pure=False,
                workers=[worker])
            for worker, parts in worker_parts
        }

        fit_futures = list(pca_fit.values())
        wait(fit_futures)
        raise_exception_from_futures(fit_futures)
    finally:
        comms.destroy()

    # Pull the first worker's fitted model locally and mirror its fitted
    # attributes onto this estimator.
    self.local_model = list(models.values())[0].result()
    self.components_ = self.local_model.components_
    self.explained_variance_ = self.local_model.explained_variance_
    self.explained_variance_ratio_ = \
        self.local_model.explained_variance_ratio_
    self.singular_values_ = self.local_model.singular_values_

    if _transform:
        out_futures = flatten_grouped_results(self.client,
                                              data.gpu_futures,
                                              pca_fit)
        return to_output(out_futures, self.datatype)

    return self