コード例 #1
0
ファイル: base.py プロジェクト: isVoid/cuml
    def _fit(self, X, _transform=False):
        """
        Fit the model with X.

        Parameters
        ----------
        X : dask cuDF input
        _transform : bool, default False
            Forwarded unchanged to every per-worker ``_func_fit`` task.
            When truthy, the per-worker results are flattened and returned
            instead of ``self``.

        Returns
        -------
        ``self`` when ``_transform`` is falsy; otherwise the value produced
        by ``to_output`` from the flattened per-worker fit results.

        """

        n_cols = X.shape[1]

        data = DistributedDataHandler.create(data=X, client=self.client)
        self.datatype = data.datatype

        # Only the "tsqr" solver asks for point-to-point comms.
        use_p2p = self.kwargs.get("svd_solver") == "tsqr"
        comms = CommsContext(comms_p2p=use_p2p)
        comms.init(workers=data.workers)

        try:
            data.calculate_parts_to_sizes(comms)

            worker_info = comms.worker_info(comms.worker_addresses)
            parts_to_sizes, _ = parts_to_ranks(self.client, worker_info,
                                               data.gpu_futures)

            total_rows = data.total_rows

            # One model future per worker, keyed by that worker's rank and
            # pinned to the worker that owns the partitions.
            models = {
                data.worker_info[worker]["rank"]: self.client.submit(
                    self._create_model,
                    comms.sessionId,
                    self._model_func,
                    self.datatype,
                    **self.kwargs,
                    pure=False,
                    workers=[worker])
                for worker in data.worker_to_parts
            }

            # One fit task per worker, keyed by worker, run against the
            # model future created for the same rank.
            pca_fit = {
                worker: self.client.submit(
                    DecompositionSyncFitMixin._func_fit,
                    models[data.worker_info[worker]["rank"]],
                    parts,
                    total_rows,
                    n_cols,
                    parts_to_sizes,
                    data.worker_info[worker]["rank"],
                    _transform,
                    pure=False,
                    workers=[worker])
                for worker, parts in data.worker_to_parts.items()
            }

            fit_futures = list(pca_fit.values())
            wait(fit_futures)
            raise_exception_from_futures(fit_futures)
        finally:
            # Tear the comms down even when a step above raised, so a
            # failed fit does not leave them initialised.
            comms.destroy()

        self._set_internal_model(list(models.values())[0])

        if _transform:
            out_futures = flatten_grouped_results(self.client,
                                                  data.gpu_futures, pca_fit)
            return to_output(out_futures, self.datatype)

        return self
コード例 #2
0
ファイル: base.py プロジェクト: zhangjianting/cuml
    def _fit(self, X, _transform=False):
        """
        Fit the model with X.

        After the per-worker fits finish, the first model's result is
        stored as ``self.local_model`` and its ``components_``,
        ``explained_variance_``, ``explained_variance_ratio_`` and
        ``singular_values_`` are copied onto ``self``.

        Parameters
        ----------
        X : dask cuDF input
        _transform : bool, default False
            Forwarded unchanged to every per-worker ``_func_fit`` task.
            When truthy, the per-worker results are flattened and returned
            instead of ``self``.

        Returns
        -------
        ``self`` when ``_transform`` is falsy; otherwise the value produced
        by ``to_output`` from the flattened per-worker fit results.

        """

        n_cols = X.shape[1]

        data = DistributedDataHandler.create(data=X, client=self.client)
        self.datatype = data.datatype

        comms = CommsContext(comms_p2p=False)
        comms.init(workers=data.workers)

        try:
            data.calculate_parts_to_sizes(comms)

            total_rows = data.total_rows

            # One model future per worker, keyed by that worker's rank and
            # pinned to the worker that owns the partitions.
            models = {
                data.worker_info[worker]["rank"]: self.client.submit(
                    self._create_model,
                    comms.sessionId,
                    self._model_func,
                    self.datatype,
                    **self.kwargs,
                    pure=False,
                    workers=[worker])
                for worker in data.worker_to_parts
            }

            # One fit task per worker, keyed by worker; each receives the
            # parts_to_sizes entry for its own rank.
            pca_fit = {
                worker: self.client.submit(
                    DecompositionSyncFitMixin._func_fit,
                    models[data.worker_info[worker]["rank"]],
                    parts,
                    total_rows,
                    n_cols,
                    data.parts_to_sizes[data.worker_info[worker]["rank"]],
                    data.worker_info[worker]["rank"],
                    _transform,
                    pure=False,
                    workers=[worker])
                for worker, parts in data.worker_to_parts.items()
            }

            fit_futures = list(pca_fit.values())
            wait(fit_futures)
            raise_exception_from_futures(fit_futures)
        finally:
            # Tear the comms down even when a step above raised, so a
            # failed fit does not leave them initialised.
            comms.destroy()

        # Pull the first fitted model back to the client and mirror its
        # fitted attributes on this object.
        self.local_model = list(models.values())[0].result()

        self.components_ = self.local_model.components_
        self.explained_variance_ = self.local_model.explained_variance_
        self.explained_variance_ratio_ = \
            self.local_model.explained_variance_ratio_
        self.singular_values_ = self.local_model.singular_values_

        if _transform:
            out_futures = flatten_grouped_results(self.client,
                                                  data.gpu_futures, pca_fit)
            return to_output(out_futures, self.datatype)

        return self