Example #1
    def _create_worker_monitor(self, worker_address):
        """When a new worker connects to the master, a socket is created to
        send heartbeat signals to the worker.
        """
        worker_heartbeat_socket = self.ctx.socket(zmq.REQ)
        worker_heartbeat_socket.linger = 0
        worker_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        worker_heartbeat_socket.connect("tcp://" + worker_address)

        connected = True
        while connected and self.master_is_alive:
            try:
                worker_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                worker_status = worker_heartbeat_socket.recv_multipart()
                vacant_cpus = self.job_center.get_vacant_cpu(worker_address)
                total_cpus = self.job_center.get_total_cpu(worker_address)
                self.cluster_monitor.update_worker_status(
                    worker_status, worker_address, vacant_cpus, total_cpus)
                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
            except zmq.error.Again as e:
                self.job_center.drop_worker(worker_address)
                self.cluster_monitor.drop_worker_status(worker_address)
                logger.warning("\n[Master] Cannot connect to the worker " +
                               "{}. ".format(worker_address) +
                               "Worker_pool will drop this worker.")
                self._print_workers()
                connected = False
            except zmq.error.ZMQError as e:
                break

        worker_heartbeat_socket.close(0)
        logger.warning("Exit worker monitor from master.")
Example #2
File: worker.py Project: YuechengLiu/PARL
    def _kill_job(self, job_address):
        """Kill a job process and update worker information"""
        success = self.worker_status.remove_job(job_address)
        if success:
            while True:
                initialized_job = self.job_buffer.get()
                initialized_job.worker_address = self.master_heartbeat_address
                if initialized_job.is_alive:
                    self.worker_status.add_job(initialized_job)
                    # Double-check aliveness: the job may have died between
                    # the check above and add_job; if so, roll back and fetch
                    # the next job from the buffer.
                    if not initialized_job.is_alive:
                        self.worker_status.remove_job(
                            initialized_job.job_address)
                        continue
                else:
                    logger.warning(
                        "[Worker] a dead job found. The job buffer will not accept this one."
                    )
                if initialized_job.is_alive:
                    break

            self.lock.acquire()
            self.request_master_socket.send_multipart([
                remote_constants.NEW_JOB_TAG,
                cloudpickle.dumps(initialized_job),
                to_byte(job_address)
            ])
            _ = self.request_master_socket.recv_multipart()
            self.lock.release()
Example #3
    def run(self):
        """An infinite loop waiting for messages from the workers and
        clients.

        Master node will receive four types of messages:

        1. A new worker connects to the master node.
        2. A connected worker sending new job address after it kills an old
           job.
        3. A new client connects to the master node.
        4. A connected client submits a job after a remote object is created.
        """
        self.client_socket.linger = 0
        self.client_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)

        while self.master_is_alive:
            try:
                self._receive_message()
            except zmq.error.Again as e:
                # Receive timed out; loop around so that `self.master_is_alive`
                # is re-checked periodically.
                pass

        logger.warning("[Master] Exit master.")
Example #4
File: worker.py Project: YuechengLiu/PARL
    def _create_job_monitor(self, job):
        """Send heartbeat signals to check target's status"""

        # job_heartbeat_socket: sends heartbeat signal to job
        job_heartbeat_socket = self.ctx.socket(zmq.REQ)
        job_heartbeat_socket.linger = 0
        job_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        job_heartbeat_socket.connect("tcp://" + job.worker_heartbeat_address)

        job.is_alive = True
        while job.is_alive and self.master_is_alive and self.worker_is_alive:
            try:
                job_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                _ = job_heartbeat_socket.recv_multipart()
                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
            except zmq.error.Again as e:
                job.is_alive = False
                logger.warning(
                    "[Worker] lost connection with the job: {}".format(
                        job.job_address))
                if self.master_is_alive and self.worker_is_alive:
                    self._kill_job(job.job_address)

            except zmq.error.ZMQError as e:
                break

        job_heartbeat_socket.close(0)
Example #5
def get_ip_address():
    """
    get the IP address of the host.
    """

    # Windows
    if _IS_WINDOWS:
        local_ip = socket.gethostbyname(socket.gethostname())
    else:
        # Linux and MacOS
        local_ip = None
        try:
            # First approach, tested on Ubuntu and MacOS: connect a UDP
            # socket to a public address to discover the outbound interface.
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            s.connect(("8.8.8.8", 80))
            local_ip = s.getsockname()[0]
            s.close()
        except Exception:
            # Second approach, tested on CentOS
            try:
                local_ip = socket.gethostbyname(socket.gethostname())
            except Exception:
                pass

    if local_ip is None or local_ip in ('127.0.0.1', '127.0.1.1'):
        logger.warning(
            'get_ip_address failed, please set ip address manually.')
        return None

    return local_ip
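A quick usage sketch of the helper above; handling the `None` return is the caller's job, and the fallback address here is only illustrative.

ip = get_ip_address()
if ip is None:
    ip = '127.0.0.1'  # e.g., fall back to loopback or a user-supplied address
print("serving on {}".format(ip))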
Example #6
    def step(self, action, **kwargs):
        obs, r, done, info = self.env.step(action, **kwargs)
        # early stop condition
        if info['target_changed']:
            info['timeout'] = True
            done = True
            logger.warning(
                '[FirstTarget Wrapper] early stop since first target is finished.'
            )
        return obs, r, done, info
Example #7
    def clear(self):
        """Remove all the jobs"""
        self._lock.acquire()
        for job in self.jobs.values():
            try:
                os.kill(job.pid, signal.SIGTERM)
            except OSError:
                logger.warning("job:{} has been killed before".format(job.pid))
            logger.info("[Worker] kills a job:{}".format(job.pid))
        self.jobs = dict()
        self._lock.release()
Example #8
    def call(*args, **kwargs):
        global _writer
        if _writer is None:
            logdir = logger.get_dir()
            if logdir is None:
                logdir = logger.auto_set_dir(action='d')
                logger.warning(
                    "[VisualDL] logdir is None, will save VisualDL files to {}\nView the data using: visualdl --logdir=./{} --host={}"
                    .format(logdir, logdir, get_ip_address()))
            _writer = LogWriter(logdir=logger.get_dir())
        func = getattr(_writer, func_name)
        func(*args, **kwargs)
        _writer.flush()
Example #9
    def call(*args, **kwargs):
        global _writer
        if _writer is None:
            logdir = logger.get_dir()
            if logdir is None:
                logdir = logger.auto_set_dir(action='d')
                logger.warning(
                    "[tensorboard] logdir is None, will save tensorboard files to {}\nView the data using: tensorboard --logdir=./{} --host={}"
                    .format(logdir, logdir, get_ip_address()))
            _writer = SummaryWriter(logdir=logger.get_dir())
        func = getattr(_writer, func_name)
        func(*args, **kwargs)
        _writer.flush()
Example #10
    def request_cpu_resource(self, global_client, max_memory):
        """Try to request a CPU resource once per second, up to 300 times."""
        cnt = 300
        while cnt > 0:
            job_address = global_client.submit_job(max_memory)
            if job_address is not None:
                return job_address
            if cnt % 30 == 0:
                logger.warning(
                    "No vacant cpu resources at the moment, "
                    "will retry {} more times.".format(cnt))
            cnt -= 1
        return None
Example #11
File: job.py Project: YuechengLiu/PARL
    def _reply_worker_heartbeat(self, socket):
        """Create a socket that replies to heartbeat signals from the worker.
        If the worker has exited, the job will exit automatically.
        """
        while True:
            try:
                message = socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again as e:
                logger.warning("[Job] Cannot connect to the worker {}. ".format(
                    self.worker_address) + "Job will quit.")
                break
        socket.close(0)
        os._exit(1)
Example #12
File: Coach.py Project: YuechengLiu/PARL
    def loadTrainExamples(self):
        modelFile = os.path.join(self.args.load_folder_file[0],
                                 self.args.load_folder_file[1])
        examplesFile = modelFile + ".examples"
        if not os.path.isfile(examplesFile):
            logger.warning(
                "File {} with trainExamples not found!".format(examplesFile))
            r = input("Continue? [y|n]")
            if r != "y":
                sys.exit()
        else:
            logger.info("File with trainExamples found. Loading it...")
            with open(examplesFile, "rb") as f:
                self.trainExamplesHistory = Unpickler(f).load()
            logger.info('Loading done!')
Example #13
def is_gpu_available():
    """ check whether parl can access a GPU

    Returns:
      True if a gpu device can be found.
    """
    ret = get_gpu_count() > 0
    if _HAS_FLUID:
        from paddle import fluid
        if ret is True and not fluid.is_compiled_with_cuda():
            logger.warning("Found non-empty CUDA_VISIBLE_DEVICES. \
                But PARL found that Paddle was not complied with CUDA, which may cause issues. \
                Thus PARL will not use GPU.")
            return False
    return ret
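A usage sketch for the check above; the device strings are illustrative, not a specific framework API.

device = 'gpu' if is_gpu_available() else 'cpu'
logger.info("running on {}".format(device))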
Example #14
File: worker.py Project: YuechengLiu/PARL
    def _reply_heartbeat(self, target):
        """Worker will kill its jobs when it lost connection with the master.
        """

        socket = self.ctx.socket(zmq.REP)
        socket.linger = 0
        socket.setsockopt(zmq.RCVTIMEO,
                          remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_master_port =\
            socket.bind_to_random_port("tcp://*")
        self.master_heartbeat_address = "{}:{}".format(self.worker_ip,
                                                       heartbeat_master_port)

        logger.set_dir(
            os.path.expanduser('~/.parl_data/worker/{}'.format(
                self.master_heartbeat_address.replace(':', '_'))))

        self.heartbeat_socket_initialized.set()
        logger.info("[Worker] Connect to the master node successfully. "
                    "({} CPUs)".format(self.cpu_num))
        while self.master_is_alive and self.worker_is_alive:
            try:
                message = socket.recv_multipart()
                worker_status = self._get_worker_status()
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(worker_status[0])),
                    to_byte(str(worker_status[1])),
                    to_byte(worker_status[2]),
                    to_byte(str(worker_status[3]))
                ])
            except zmq.error.Again as e:
                self.master_is_alive = False
            except zmq.error.ContextTerminated as e:
                break
        socket.close(0)
        logger.warning(
            "[Worker] lost connection with the master; exiting the heartbeat reply loop."
        )
        self.worker_status.clear()
        self.log_server_proc.kill()
        self.log_server_proc.wait()
        # exit the worker
        self.worker_is_alive = False
        self.exit()
Example #15
    def remove_job(self, killed_job):
        """Rmove a job from internal job pool.

        Args:
            killed_job(str): Job address to be removed.

        Returns: True if removing the job succeeds.
        """
        ret = False
        self._lock.acquire()
        if killed_job in self.jobs:
            pid = self.jobs[killed_job].pid
            self.jobs.pop(killed_job)
            ret = True
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError:
                logger.warning("job:{} has been killed before".format(pid))
            logger.info("[Worker] kills a job:{}".format(killed_job))
        self._lock.release()
        return ret
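The `try/except OSError` around `os.kill` is the standard way to tolerate a process that has already exited; below is a small standalone sketch of the same pattern.

import os
import signal

def terminate(pid):
    """Send SIGTERM, tolerating processes that already exited."""
    try:
        os.kill(pid, signal.SIGTERM)
        return True
    except OSError:  # e.g., no process with this pid
        return False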
Example #16
    def _create_sockets(self, master_address):
        """ Each client has 1 sockets as start:

        (1) submit_job_socket: submits jobs to master node.
        """

        # submit_job_socket: submits job to master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()
        thread = threading.Thread(target=self._reply_heartbeat)
        thread.setDaemon(True)
        thread.start()
        self.heartbeat_socket_initialized.wait()

        self.client_id = self.reply_master_heartbeat_address.replace(':', '_') + \
                            '_' + str(int(time.time()))

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.reply_master_heartbeat_address),
                to_byte(socket.gethostname()),
                to_byte(self.client_id),
            ])
            message = self.submit_job_socket.recv_multipart()
            self.log_monitor_url = to_str(message[1])
        except zmq.error.Again as e:
            logger.warning("[Client] Can not connect to the master, please "
                           "check if master is started and ensure the input "
                           "address {} is correct.".format(master_address))
            self.master_is_alive = False
            raise Exception("Client can not connect to the master, please "
                            "check if master is started and ensure the input "
                            "address {} is correct.".format(master_address))
Example #17
    def step(self, action, **kwargs):
        r = 0.0
        merge_info = {}
        for k in range(self.skip_num):
            self.frame_count += 1
            obs, reward, done, info = self.env.step(action, **kwargs)
            r += reward

            for key in info.keys():
                if 'reward' in key:
                    merge_info[key] = merge_info.get(key, 0.0) + info[key]
                else:
                    merge_info[key] = info[key]

            if info['target_changed']:
                logger.warning(
                    "[FrameSkip Wrapper] early break since the target has changed")
                break

            if done:
                break
        merge_info['frame_count'] = self.frame_count
        return obs, r, done, merge_info
Example #18
    def _create_client_monitor(self, client_heartbeat_address):
        """When a new client connects to the master, a socket is created to
        send heartbeat signals to the client.
        """

        client_heartbeat_socket = self.ctx.socket(zmq.REQ)
        client_heartbeat_socket.linger = 0
        client_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        client_heartbeat_socket.connect("tcp://" + client_heartbeat_address)

        client_is_alive = True
        while client_is_alive and self.master_is_alive:
            try:
                client_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                client_status = client_heartbeat_socket.recv_multipart()

                self.cluster_monitor.update_client_status(
                    client_status, client_heartbeat_address,
                    self.client_hostname[client_heartbeat_address])

            except zmq.error.Again as e:
                client_is_alive = False
                self.cluster_monitor.drop_client_status(
                    client_heartbeat_address)
                logger.warning("[Master] cannot connect to the client " +
                               "{}. ".format(client_heartbeat_address) +
                               "Please check if it is still alive.")
            time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
        logger.warning("Master exits client monitor for {}.\n".format(
            client_heartbeat_address))
        logger.info(
            "Master connects to {} workers and has {} vacant CPUs.\n".format(
                self.worker_num, self.cpu_num))
        client_heartbeat_socket.close(0)
Example #19
File: job.py Project: YuechengLiu/PARL
    def _reply_client_heartbeat(self, socket):
        """Create a socket that replies to heartbeat signals from the client.
        If the job loses connection with the client, it will exit too.
        """
        while True:
            try:
                message = socket.recv_multipart()
                stop_job = self._check_used_memory()
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(stop_job)),
                    to_byte(self.job_address)
                ])
                if stop_job:
                    logger.error(
                        "Memory used by this job exceeds {}. This job will exit."
                        .format(self.max_memory))
                    time.sleep(5)
                    socket.close(0)
                    os._exit(1)
            except zmq.error.Again as e:
                logger.warning(
                    "[Job] Cannot connect to the client. This job will exit and inform the worker."
                )
                break
        socket.close(0)
        with self.lock:
            self.kill_job_socket.send_multipart(
                [remote_constants.KILLJOB_TAG,
                 to_byte(self.job_address)])
            try:
                _ = self.kill_job_socket.recv_multipart()
            except zmq.error.Again as e:
                pass
        logger.warning("[Job] lost connection with the client, will exit")
        os._exit(1)
Example #20
    def _reply_heartbeat(self):
        """Reply heartbeat signals to the master node."""

        socket = self.ctx.socket(zmq.REP)
        socket.linger = 0
        socket.setsockopt(zmq.RCVTIMEO,
                          remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        reply_master_heartbeat_port =\
            socket.bind_to_random_port(addr="tcp://*")
        self.reply_master_heartbeat_address = "{}:{}".format(
            get_ip_address(), reply_master_heartbeat_port)
        self.heartbeat_socket_initialized.set()
        connected = False
        while self.client_is_alive and self.master_is_alive:
            try:
                message = socket.recv_multipart()
                elapsed_time = datetime.timedelta(seconds=int(time.time() -
                                                              self.start_time))
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(self.executable_path),
                    to_byte(str(self.actor_num)),
                    to_byte(str(elapsed_time)),
                    to_byte(str(self.log_monitor_url)),
                ])  # TODO: remove additional information
            except zmq.error.Again as e:
                if connected:
                    logger.warning("[Client] Cannot connect to the master. "
                                   "Please check if it is still alive.")
                else:
                    logger.warning(
                        "[Client] Cannot connect to the master. "
                        "Please check the firewall between the client and the master (e.g., ping the master IP)."
                    )
                self.master_is_alive = False
        socket.close(0)
        logger.warning("Client exit replying heartbeat for master.")
Example #21
File: job.py Project: YuechengLiu/PARL
    def single_task(self, obj, reply_socket, job_address):
        """An infinite loop waiting for commands from the remote object.

        Each job will receive two kinds of message from the remote object:

        1. When the remote object calls a function, job will run the
           function on the local instance and return the results to the
           remote object.
        2. When the remote object is deleted, the job will quit and release
           related computation resources.

        Args:
            reply_socket (sockert): main socket to accept commands of remote object.
            job_address (String): address of reply_socket.
        """

        while True:
            message = reply_socket.recv_multipart()

            tag = message[0]

            if tag in [
                    remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE,
                    remote_constants.SET_ATTRIBUTE
            ]:
                try:
                    if tag == remote_constants.CALL_TAG:
                        function_name = to_str(message[1])
                        data = message[2]
                        args, kwargs = loads_argument(data)

                        # Redirect stdout to stdout.log temporarily
                        logfile_path = os.path.join(self.log_dir, 'stdout.log')
                        with redirect_stdout_to_file(logfile_path):
                            ret = getattr(obj, function_name)(*args, **kwargs)

                        ret = dumps_return(ret)

                        reply_socket.send_multipart(
                            [remote_constants.NORMAL_TAG, ret])

                    elif tag == remote_constants.GET_ATTRIBUTE:
                        attribute_name = to_str(message[1])
                        logfile_path = os.path.join(self.log_dir, 'stdout.log')
                        with redirect_stdout_to_file(logfile_path):
                            ret = getattr(obj, attribute_name)
                        ret = dumps_return(ret)
                        reply_socket.send_multipart(
                            [remote_constants.NORMAL_TAG, ret])
                    else:
                        attribute_name = to_str(message[1])
                        attribute_value = loads_return(message[2])
                        logfile_path = os.path.join(self.log_dir, 'stdout.log')
                        with redirect_stdout_to_file(logfile_path):
                            setattr(obj, attribute_name, attribute_value)
                        reply_socket.send_multipart(
                            [remote_constants.NORMAL_TAG])

                except Exception as e:
                    # reset the job

                    error_str = str(e)
                    logger.error(error_str)

                    if type(e) == AttributeError:
                        reply_socket.send_multipart([
                            remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                            to_byte(error_str)
                        ])
                        raise AttributeError

                    elif type(e) == SerializeError:
                        reply_socket.send_multipart([
                            remote_constants.SERIALIZE_EXCEPTION_TAG,
                            to_byte(error_str)
                        ])
                        raise SerializeError

                    elif type(e) == DeserializeError:
                        reply_socket.send_multipart([
                            remote_constants.DESERIALIZE_EXCEPTION_TAG,
                            to_byte(error_str)
                        ])
                        raise DeserializeError

                    else:
                        traceback_str = str(traceback.format_exc())
                        logger.error("traceback:\n{}".format(traceback_str))
                        reply_socket.send_multipart([
                            remote_constants.EXCEPTION_TAG,
                            to_byte(error_str + "\ntraceback:\n" +
                                    traceback_str)
                        ])
                        break

            # receive KILLJOB_TAG from the actor, and stop replying to the worker heartbeat
            elif tag == remote_constants.KILLJOB_TAG:
                reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                logger.warning(
                    "An actor exits and this job {} will exit.".format(
                        job_address))
                break
            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError
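The request/reply pairs handled by `single_task` can be summarized as follows; this is a sketch of the message layouts seen above, and the actual byte values of the tags live in `remote_constants` and are not shown in this listing.

#   [CALL_TAG, function_name, serialized_args]   -> [NORMAL_TAG, serialized_return]
#   [GET_ATTRIBUTE, attribute_name]              -> [NORMAL_TAG, serialized_value]
#   [SET_ATTRIBUTE, attribute_name, new_value]   -> [NORMAL_TAG]
#   [KILLJOB_TAG]                                -> [NORMAL_TAG], then the job exits
#   any failure                                  -> [..._EXCEPTION_TAG, error_str]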
Example #22
File: Coach.py Project: YuechengLiu/PARL
    def learn(self):
        """Each iteration:
        1. Performs numEps episodes of self-play.
        2. Retrains neural network with examples in trainExamplesHistory
           (which has a maximum length of numItersForTrainExamplesHistory).
        3. Evaluates the new neural network with the test dataset.
        4. Pits the new neural network against the old one and accepts it
           only if it wins >= updateThreshold fraction of games.
        """

        # create remote actors to run tasks (self-play/pitting/evaluate_test_dataset) in parallel.
        self._create_remote_actors()

        for iteration in range(1, self.args.numIters + 1):
            logger.info('Starting Iter #{} ...'.format(iteration))

            ####################
            logger.info('Step1: self-play in parallel...')
            iterationTrainExamples = []
            # update weights of remote actors to the latest weights, and ask them to run self-play task
            for signal_queue in self.remote_actors_signal_queues:
                signal_queue.put({"task": "self-play"})
            # wait for all remote actors (a total of self.args.actors_num) to return the self-play results
            for _ in range(self.args.actors_num):
                result = self.remote_actors_return_queue.get()
                iterationTrainExamples.extend(result["self-play"])

            # save the iteration examples to the history
            self.trainExamplesHistory.append(iterationTrainExamples)
            if len(self.trainExamplesHistory
                   ) > self.args.numItersForTrainExamplesHistory:
                logger.warning("Removing the oldest entry in trainExamples.")
                self.trainExamplesHistory.pop(0)
            self.saveTrainExamples(iteration)  # backup history to a file

            ####################
            logger.info('Step2: train neural network...')
            # shuffle examples before training
            trainExamples = []
            for e in self.trainExamplesHistory:
                trainExamples.extend(e)
            shuffle(trainExamples)

            # training new network, keeping a copy of the old one
            self.current_agent.save(
                os.path.join(self.args.checkpoint, 'temp.pth.tar'))
            self.previous_agent.restore(
                os.path.join(self.args.checkpoint, 'temp.pth.tar'))

            self.current_agent.learn(trainExamples)

            ####################
            logger.info('Step3: evaluate test dataset in parallel...')
            cnt = 0
            # update weights of remote actors to the latest weights, and ask them to evaluate assigned test dataset
            for i, data in enumerate(
                    split_group(self.test_dataset,
                                len(self.test_dataset) //
                                self.args.actors_num)):
                self.remote_actors_signal_queues[i].put({
                    "task": "evaluate_test_dataset",
                    "test_dataset": data
                })
                cnt += len(data)
            perfect_moves_cnt, good_moves_cnt = 0, 0
            # wait for all remote actors (a total of self.args.actors_num) to return the evaluating results
            for _ in range(self.args.actors_num):
                (perfect_moves,
                 good_moves) = self.remote_actors_return_queue.get(
                 )["evaluate_test_dataset"]
                perfect_moves_cnt += perfect_moves
                good_moves_cnt += good_moves
            logger.info('perfect moves rate: {}, good moves rate: {}'.format(
                perfect_moves_cnt / cnt, good_moves_cnt / cnt))
            tensorboard.add_scalar('perfect_moves_rate',
                                   perfect_moves_cnt / cnt, iteration)
            tensorboard.add_scalar('good_moves_rate', good_moves_cnt / cnt,
                                   iteration)

            ####################
            logger.info(
                'Step4: pitting against previous generation in parallel...')
            # transfer weights of previous generation and current generation to the remote actors, and ask them to pit.
            for signal_queue in self.remote_actors_signal_queues:
                signal_queue.put({"task": "pitting"})
            previous_wins, current_wins, draws = 0, 0, 0
            for _ in range(self.args.actors_num):
                (pwins_, cwins_,
                 draws_) = self.remote_actors_return_queue.get()["pitting"]
                previous_wins += pwins_
                current_wins += cwins_
                draws += draws_

            logger.info('NEW/PREV WINS : %d / %d ; DRAWS : %d' %
                        (current_wins, previous_wins, draws))
            if previous_wins + current_wins == 0 or float(current_wins) / (
                    previous_wins + current_wins) < self.args.updateThreshold:
                logger.info('REJECTING NEW MODEL')
                self.current_agent.restore(
                    os.path.join(self.args.checkpoint, 'temp.pth.tar'))
            else:
                logger.info('ACCEPTING NEW MODEL')
                self.current_agent.save(
                    os.path.join(self.args.checkpoint, 'best.pth.tar'))
            self.current_agent.save(
                os.path.join(self.args.checkpoint,
                             self.getCheckpointFile(iteration)))
Example #23
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from parl.utils.utils import _HAS_FLUID, _HAS_TORCH
from parl.utils import logger

if _HAS_FLUID:
    from parl.algorithms.fluid import *
elif _HAS_TORCH:
    from parl.algorithms.torch import *
else:
    logger.warning(
        "No deep learning framework was found, but it's ok for parallel computation."
    )
Example #24
    def decorator(cls):
        # we are not going to create a remote actor in job.py
        if 'XPARL' in os.environ and os.environ['XPARL'] == 'True':
            logger.warning(
                "Note: this object will be running as a local object")
            return cls

        class RemoteWrapper(object):
            """
            Wrapper for remote class in client side.
            """
            def __init__(self, *args, **kwargs):
                """
                Args:
                    args, kwargs: arguments for the initialization of the unwrapped
                    class.
                """
                self.GLOBAL_CLIENT = get_global_client()

                self.ctx = self.GLOBAL_CLIENT.ctx

                # GLOBAL_CLIENT will set `master_is_alive` to False when the
                # heartbeat finds the master is dead.
                if self.GLOBAL_CLIENT.master_is_alive:
                    job_address = self.request_cpu_resource(
                        self.GLOBAL_CLIENT, max_memory)
                else:
                    raise Exception("Can not submit job to the master. "
                                    "Please check if master is still alive.")

                if job_address is None:
                    raise ResourceError("Cannot submit the job to the master. "
                                        "Please add more CPU resources to the "
                                        "master or try again later.")

                self.internal_lock = threading.Lock()

                # Send actor commands like `init` and `call` to the job.
                self.job_socket = self.ctx.socket(zmq.REQ)
                self.job_socket.linger = 0
                self.job_socket.connect("tcp://{}".format(job_address))
                self.job_address = job_address
                self.job_shutdown = False

                self.send_file(self.job_socket)
                file_name = inspect.getfile(cls)[:-3]
                cls_source = inspect.getsourcelines(cls)
                end_of_file = cls_source[1] + len(cls_source[0])
                class_name = cls.__name__
                self.job_socket.send_multipart([
                    remote_constants.INIT_OBJECT_TAG,
                    cloudpickle.dumps([file_name, class_name, end_of_file]),
                    cloudpickle.dumps([args, kwargs]),
                ])
                message = self.job_socket.recv_multipart()
                tag = message[0]
                if tag == remote_constants.EXCEPTION_TAG:
                    traceback_str = to_str(message[1])
                    self.job_shutdown = True
                    raise RemoteError('__init__', traceback_str)

            def __del__(self):
                """Delete the remote class object and release remote resources."""
                try:
                    self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000)
                except AttributeError:
                    pass
                if not self.job_shutdown:
                    try:
                        self.job_socket.send_multipart(
                            [remote_constants.KILLJOB_TAG])
                        _ = self.job_socket.recv_multipart()
                        self.job_socket.close(0)
                    except AttributeError:
                        pass
                    except zmq.error.ZMQError:
                        pass
                    except TypeError:
                        pass

            def send_file(self, socket):
                try:
                    socket.send_multipart([
                        remote_constants.SEND_FILE_TAG,
                        self.GLOBAL_CLIENT.pyfiles
                    ])
                    _ = socket.recv_multipart()
                except zmq.error.Again as e:
                    logger.error("Send python files failed.")

            def request_cpu_resource(self, global_client, max_memory):
                """Try to request cpu resource for 1 second/time for 300 times."""
                cnt = 300
                while cnt > 0:
                    job_address = global_client.submit_job(max_memory)
                    if job_address is not None:
                        return job_address
                    if cnt % 30 == 0:
                        logger.warning(
                            "No vacant cpu resources at the moment, "
                            "will retry {} more times.".format(cnt))
                    cnt -= 1
                return None

            def __setattr__(self, attr, value):
                if attr not in cls().__dict__:
                    super().__setattr__(attr, value)
                else:
                    self.internal_lock.acquire()
                    self.job_socket.send_multipart([
                        remote_constants.SET_ATTRIBUTE,
                        to_byte(attr),
                        dumps_return(value)
                    ])
                    message = self.job_socket.recv_multipart()
                    tag = message[0]
                    self.internal_lock.release()
                    if tag == remote_constants.NORMAL_TAG:
                        pass
                    else:
                        self.job_shutdown = True
                        raise NotImplementedError()
                    return

            def __getattr__(self, attr):
                """Call the function of the unwrapped class."""
                # check whether `attr` is a data attribute (fetched remotely)
                # or a method (proxied through the wrapper below)
                if attr in cls().__dict__:
                    self.internal_lock.acquire()
                    self.job_socket.send_multipart(
                        [remote_constants.GET_ATTRIBUTE,
                         to_byte(attr)])
                    message = self.job_socket.recv_multipart()
                    tag = message[0]

                    if tag == remote_constants.NORMAL_TAG:
                        ret = loads_return(message[1])
                        self.internal_lock.release()
                        return ret
                    elif tag == remote_constants.EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteError(attr, error_str)

                    elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteAttributeError(attr, error_str)

                    elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteSerializeError(attr, error_str)

                    elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteDeserializeError(attr, error_str)

                    else:
                        self.job_shutdown = True
                        raise NotImplementedError()

                def wrapper(*args, **kwargs):
                    if self.job_shutdown:
                        raise RemoteError(
                            attr, "This actor lost connection with the job.")
                    self.internal_lock.acquire()
                    data = dumps_argument(*args, **kwargs)

                    self.job_socket.send_multipart(
                        [remote_constants.CALL_TAG,
                         to_byte(attr), data])

                    message = self.job_socket.recv_multipart()
                    tag = message[0]

                    if tag == remote_constants.NORMAL_TAG:
                        ret = loads_return(message[1])

                    elif tag == remote_constants.EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteError(attr, error_str)

                    elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteAttributeError(attr, error_str)

                    elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteSerializeError(attr, error_str)

                    elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
                        error_str = to_str(message[1])
                        self.job_shutdown = True
                        raise RemoteDeserializeError(attr, error_str)

                    else:
                        self.job_shutdown = True
                        raise NotImplementedError()

                    self.internal_lock.release()
                    return ret

                return wrapper

        RemoteWrapper._original = cls
        return RemoteWrapper
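For context, a hedged usage sketch of the wrapper above, assuming it backs PARL's `@parl.remote_class` decorator; the class and values are invented for illustration, and the commented lines require a running xparl cluster.

import parl

@parl.remote_class
class Actor(object):
    def __init__(self, seed=0):
        self.seed = seed

    def compute(self, x):
        return x * x

# parl.connect("localhost:8010")  # connect to a running xparl master first
# actor = Actor(seed=0)           # __init__ runs inside a remote job
# print(actor.compute(3))         # proxied via __getattr__ and CALL_TAG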