def execute(cls, ctx, op):
    """Run XGBoost training for one chunk of a train operand.

    When ``op.merge`` is set, execution is delegated to the parent class.
    Otherwise the training DMatrix (``op.dtrain``) and any evaluation
    DMatrix objects (``op.evals``) are read from ``ctx`` and passed to
    :func:`xgboost.train`.

    If ``op.tracker`` is ``None`` training is purely local. Otherwise the
    rabit arguments stored in ``ctx[op.tracker.key]`` are used to join a
    distributed rabit session, which is always finalized afterwards.

    The value stored in ``ctx[op.outputs[0].key]`` is a dict with a pickled
    ``'booster'`` and the evaluation ``'history'``. In distributed mode only
    rank 0 stores that dict; every other rank stores an empty dict.
    """
    if op.merge:
        return super().execute(ctx, op)

    from xgboost import train, rabit

    dtrain = ToDMatrix.get_xgb_dmatrix(ctx[op.dtrain.key])

    evals = tuple()
    if op.evals is not None:
        eval_dmatrices = [ToDMatrix.get_xgb_dmatrix(ctx[t[0].key])
                          for t in op.evals]
        # Pair each evaluation DMatrix with the name given in op.evals.
        evals = tuple((m, ev[1]) for m, ev in zip(eval_dmatrices, op.evals))

    # Copy so that setting 'nthread' does not mutate the operand's own params,
    # which would leak this worker's core count into the shared op.
    params = dict(op.params)
    params['nthread'] = ctx.get_ncores() or -1
    # op.kwargs may be None; treat it as "no extra arguments" in both modes
    # (previously only the local branch guarded against None).
    kwargs = dict() if op.kwargs is None else op.kwargs

    if op.tracker is None:
        # non distributed
        local_history = dict()
        bst = train(params, dtrain, evals=evals,
                    evals_result=local_history, **kwargs)
        ctx[op.outputs[0].key] = {'booster': pickle.dumps(bst),
                                  'history': local_history}
    else:
        # distributed
        rabit_args = ctx[op.tracker.key]
        # Values fetched from ctx may arrive as memoryview; rabit.init needs
        # real bytes objects.
        rabit.init([
            arg.tobytes() if isinstance(arg, memoryview) else arg
            for arg in rabit_args
        ])
        try:
            local_history = dict()
            bst = train(params, dtrain, evals=evals,
                        evals_result=local_history, **kwargs)
            # Only rank 0 publishes the model; other ranks store an empty
            # dict, so skip the (discarded) pickling work on them.
            if rabit.get_rank() == 0:
                ret = {'booster': pickle.dumps(bst), 'history': local_history}
            else:
                ret = {}
            ctx[op.outputs[0].key] = ret
        finally:
            rabit.finalize()
def stop(self):
    """Shut down the parameter server.

    Finalizes the rabit session of this process. On the master host, the
    background thread that is running the master host is joined afterwards.
    """
    self.logger.debug("Shutting down parameter server.")
    # Finalizing rabit is what actually shuts the rabit server down; once
    # every slave has been shut down, the RabitTracker closes by itself.
    rabit.finalize()
    if not self.is_master_host:
        return
    # Master host only: wait for the background thread to finish.
    self.rabit_context.join()
def execute(cls, ctx, op):
    """Run XGBoost training for one chunk of a train operand.

    When ``op.merge`` is set, execution is delegated to the parent class.
    Otherwise the training data (``op.dtrain``) and any evaluation data
    (``op.evals``) are read from ``ctx``, passed through ``ensure_own_data``,
    converted to DMatrix objects and given to :func:`xgboost.train`.

    If ``op.tracker`` is ``None`` training is purely local. Otherwise the
    rabit arguments stored in ``ctx[op.tracker.key]`` are used to join a
    distributed rabit session, which is always finalized afterwards.

    The value stored in ``ctx[op.outputs[0].key]`` is a dict with a pickled
    ``'booster'`` and the evaluation ``'history'``. In distributed mode only
    rank 0 stores that dict; every other rank stores an empty dict.
    """
    if op.merge:
        return super().execute(ctx, op)

    from xgboost import train, rabit

    dtrain = ToDMatrix.get_xgb_dmatrix(ensure_own_data(ctx[op.dtrain.key]))

    evals = tuple()
    if op.evals is not None:
        eval_dmatrices = [
            ToDMatrix.get_xgb_dmatrix(ensure_own_data(ctx[t[0].key]))
            for t in op.evals
        ]
        # Pair each evaluation DMatrix with the name given in op.evals.
        evals = tuple(
            (m, ev[1]) for m, ev in zip(eval_dmatrices, op.evals))

    params = op.params
    # op.kwargs may be None; treat it as "no extra arguments" in both modes
    # (previously only the local branch guarded against None).
    kwargs = dict() if op.kwargs is None else op.kwargs

    if op.tracker is None:
        # non distributed
        local_history = dict()
        bst = train(params, dtrain, evals=evals,
                    evals_result=local_history, **kwargs)
        ctx[op.outputs[0].key] = {
            'booster': pickle.dumps(bst),
            'history': local_history
        }
    else:
        # distributed
        rabit_args = ctx[op.tracker.key]
        # Values fetched from ctx may arrive as memoryview; rabit.init needs
        # real bytes objects.
        rabit.init([
            arg.tobytes() if isinstance(arg, memoryview) else arg
            for arg in rabit_args
        ])
        try:
            local_history = dict()
            bst = train(params, dtrain, evals=evals,
                        evals_result=local_history, **kwargs)
            # Only rank 0 publishes the model; other ranks store an empty
            # dict, so skip the (discarded) pickling work on them.
            if rabit.get_rank() == 0:
                ret = {'booster': pickle.dumps(bst), 'history': local_history}
            else:
                ret = {}
            ctx[op.outputs[0].key] = ret
        finally:
            rabit.finalize()
def __exit__(self, *args):
    """Leave the context by finalizing rabit.

    The exception details passed in ``args`` are ignored. Because ``None``
    is returned, any exception raised inside the ``with`` body propagates.
    """
    rabit.finalize()
    return None