Exemplo n.º 1
0
    def execute(cls, ctx, op):
        """Train an XGBoost booster for ``op`` and store the result in ``ctx``.

        When ``op.merge`` is set, execution is delegated to the parent class.
        Otherwise the training DMatrix is built with
        ``ToDMatrix.get_xgb_dmatrix(ctx[op.dtrain.key])``. Each entry of
        ``op.evals`` is indexed as ``(chunk, name)``: the DMatrix comes from
        ``ctx[entry[0].key]`` and ``entry[1]`` is passed through as the eval
        name. ``xgboost.train`` is then run and
        ``{'booster': pickle.dumps(booster), 'history': evals_result}`` is
        written to ``ctx[op.outputs[0].key]``.

        If ``op.tracker`` is set, training runs as a rabit worker. rabit is
        initialised with ``ctx[op.tracker.key]`` and is always finalised
        afterwards. Only rank 0 stores the booster/history; every other rank
        stores an empty dict.
        """
        if op.merge:
            return super().execute(ctx, op)

        from xgboost import train, rabit

        dtrain = ToDMatrix.get_xgb_dmatrix(ctx[op.dtrain.key])
        evals = tuple()
        if op.evals is not None:
            eval_dmatrices = [ToDMatrix.get_xgb_dmatrix(ctx[t[0].key]) for t in op.evals]
            evals = tuple((m, ev[1]) for m, ev in zip(eval_dmatrices, op.evals))

        # Work on a copy. op.params belongs to the operand, so setting nthread
        # here must not write back into it.
        params = dict(op.params)
        params['nthread'] = ctx.get_ncores() or -1
        # op.kwargs may be None. Normalise it once so that both the local and
        # the distributed branches can unpack it safely.
        kwargs = dict() if op.kwargs is None else op.kwargs

        if op.tracker is None:
            # non distributed
            local_history = dict()
            bst = train(params, dtrain, evals=evals,
                        evals_result=local_history, **kwargs)
            ctx[op.outputs[0].key] = {'booster': pickle.dumps(bst), 'history': local_history}
        else:
            # distributed
            rabit_args = ctx[op.tracker.key]
            rabit.init(rabit_args)
            try:
                local_history = dict()
                bst = train(params, dtrain, evals=evals, evals_result=local_history,
                            **kwargs)
                # Only rank 0 reports the model, so only rank 0 needs to
                # pay for pickling it.
                if rabit.get_rank() == 0:
                    ret = {'booster': pickle.dumps(bst), 'history': local_history}
                else:
                    ret = {}
                ctx[op.outputs[0].key] = ret
            finally:
                rabit.finalize()
Exemplo n.º 2
0
    def execute(cls, ctx, op):
        """Train an XGBoost booster for ``op`` and store the result in ``ctx``.

        When ``op.merge`` is set, execution is delegated to the parent class.
        Otherwise the training DMatrix is built from
        ``ensure_own_data(ctx[op.dtrain.key])`` via
        ``ToDMatrix.get_xgb_dmatrix``. ``ensure_own_data`` presumably yields
        an owned copy of the chunk data (TODO confirm). Each entry of
        ``op.evals`` is indexed as ``(chunk, name)``: the DMatrix comes from
        ``ctx[entry[0].key]`` and ``entry[1]`` is passed through as the eval
        name. ``xgboost.train`` is then run and
        ``{'booster': pickle.dumps(booster), 'history': evals_result}`` is
        written to ``ctx[op.outputs[0].key]``.

        If ``op.tracker`` is set, training runs as a rabit worker. rabit is
        initialised with ``ctx[op.tracker.key]`` (memoryview entries are
        converted to bytes) and is always finalised afterwards. Only rank 0
        stores the booster/history; every other rank stores an empty dict.
        """
        if op.merge:
            return super().execute(ctx, op)

        from xgboost import train, rabit

        dtrain = ToDMatrix.get_xgb_dmatrix(ensure_own_data(ctx[op.dtrain.key]))
        evals = tuple()
        if op.evals is not None:
            eval_dmatrices = [
                ToDMatrix.get_xgb_dmatrix(ensure_own_data(ctx[t[0].key]))
                for t in op.evals
            ]
            evals = tuple(
                (m, ev[1]) for m, ev in zip(eval_dmatrices, op.evals))
        params = op.params
        # op.kwargs may be None. Normalise it once so that both the local and
        # the distributed branches can unpack it safely.
        kwargs = dict() if op.kwargs is None else op.kwargs

        if op.tracker is None:
            # non distributed
            local_history = dict()
            bst = train(params,
                        dtrain,
                        evals=evals,
                        evals_result=local_history,
                        **kwargs)
            ctx[op.outputs[0].key] = {
                'booster': pickle.dumps(bst),
                'history': local_history
            }
        else:
            # distributed
            rabit_args = ctx[op.tracker.key]
            # Entries may arrive as memoryview objects; convert those to bytes
            # before handing the list to rabit.init.
            rabit.init([
                arg.tobytes() if isinstance(arg, memoryview) else arg
                for arg in rabit_args
            ])
            try:
                local_history = dict()
                bst = train(params,
                            dtrain,
                            evals=evals,
                            evals_result=local_history,
                            **kwargs)
                # Only rank 0 reports the model, so only rank 0 needs to
                # pay for pickling it.
                if rabit.get_rank() == 0:
                    ret = {'booster': pickle.dumps(bst), 'history': local_history}
                else:
                    ret = {}
                ctx[op.outputs[0].key] = ret
            finally:
                rabit.finalize()
Exemplo n.º 3
0
 def __enter__(self):
     """Initialise rabit with the arguments held in ``self.args``.

     Nothing is returned, so ``with ... as x`` binds ``x`` to None.
     """
     init_args = self.args
     rabit.init(init_args)
Exemplo n.º 4
0
    def start(self):
        """Start the rabit process.

        If current host is master host, initialize and start the Rabit Tracker in the background. All hosts then connect
        to the master host to set up Rabit rank.

        The tracker port is probed up to ``self.max_connect_attempts`` times (forever when it is None), sleeping
        ``self.connect_retry_timeout`` seconds between failed attempts. There is no sleep after the final attempt.

        :return: Initialized RabitHelper, which includes helpful information such as is_master and port
        :raises Exception: if the tracker could not be reached within ``self.max_connect_attempts`` attempts; the last
            connection error, if any, is chained as the cause.
        """
        self.rabit_context = None
        if self.is_master_host:
            self.logger.debug("Master host. Starting Rabit Tracker.")
            # The Rabit Tracker is a Python script that is responsible for
            # allowing each instance of rabit to find its peers and organize
            # itself in to a ring for all-reduce. It supports primitive failure
            # recovery modes.
            #
            # It runs on a master node that each of the individual Rabit instances
            # talk to.
            self.rabit_context = tracker.RabitTracker(hostIP=self.current_host,
                                                      nslave=self.n_workers,
                                                      port=self.port,
                                                      port_end=self.port + 1)

            # Useful logging to ensure that the tracker has started.
            # These are the key-value config pairs that each of the rabit slaves
            # should be initialized with. Since we have deterministically allocated
            # the master host, its port, and the number of workers, we don't need
            # to pass these out-of-band to each slave; but rely on the fact
            # that each slave will calculate the exact same config as the server.
            #
            # TODO: should probably check that these match up what we pass below.
            self.logger.info("Rabit slave environment: %s", self.rabit_context.slave_envs())

            # This actually starts the RabitTracker in a background/daemon thread
            # that will automatically exit when the main process has finished.
            self.rabit_context.start(self.n_workers)

        # Start each parameter server that connects to the master.
        self.logger.debug("Starting parameter server.")

        # Rabit runs as an in-process singleton library that can be configured once.
        # Calling this multiple times will cause a seg-fault (without calling finalize).
        # We pass it the environment variables that match up with the RabitTracker
        # so that this instance can discover its peers (and recover from failure).
        #
        # First we check that the RabitTracker is up and running. Rabit actually
        # breaks (at least on Mac OS X) if the server is not running before it
        # begins to try to connect (its internal retries fail because they reuse
        # the same socket instead of creating a new one).
        #
        # if self.max_connect_attempts is None, this will loop indefinitely.
        attempt = 0
        successful_connection = False
        last_error = None
        while self.max_connect_attempts is None or attempt < self.max_connect_attempts:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                try:
                    self.logger.debug("Checking if RabitTracker is available.")
                    s.connect((self.master_host, self.port))
                except OSError as err:
                    last_error = err
                    self.logger.info("Failed to connect to RabitTracker on attempt %s", attempt)
                else:
                    successful_connection = True
                    self.logger.debug("Successfully connected to RabitTracker.")
                    break
            attempt += 1
            if self.max_connect_attempts is not None and attempt >= self.max_connect_attempts:
                # Out of attempts: give up now instead of sleeping once more.
                break
            # The probe socket is already closed here, so we don't hold it open while sleeping.
            self.logger.info("Sleeping for %s sec before retrying", self.connect_retry_timeout)
            time.sleep(self.connect_retry_timeout)

        if not successful_connection:
            self.logger.error("Failed to connect to Rabit Tracker after %s attempts", self.max_connect_attempts)
            raise Exception("Failed to connect to Rabit Tracker") from last_error
        self.logger.info("Connected to RabitTracker.")

        rabit.init(['DMLC_NUM_WORKER={}'.format(self.n_workers).encode(),
                    'DMLC_TRACKER_URI={}'.format(self.master_host).encode(),
                    'DMLC_TRACKER_PORT={}'.format(self.port).encode()])

        # We can check that the rabit instance has successfully connected to the
        # server by getting the rank of the server (e.g. its position in the ring).
        # This should be unique for each instance.
        self.logger.debug("Rabit started - Rank %s", rabit.get_rank())
        self.logger.debug("Executing user code")

        # We can now run user-code. Since XGBoost runs in the same process space
        # it will use the same instance of rabit that we have configured. It has
        # a number of checks throughout the learning process to see if it is running
        # in distributed mode by calling rabit APIs. If it is it will do the
        # synchronization automatically.
        #
        # Hence we can now execute any XGBoost specific training code and it
        # will be distributed automatically.
        return RabitHelper(self.is_master_host, self.current_host, self.port)