Exemple #1
0
    def get_recv_thread(self):
        """Build the (unstarted) thread that forwards simulator messages.

        Every loop iteration blocks on ``self.sim2mgr_socket`` for one frame,
        decodes its payload with ``loads`` and puts the decoded object on
        ``self.queue``. The thread is created non-pausable and is not marked
        as a daemon thread.

        Returns:
            LoopThread: the receiver thread, named "recv thread"; the caller
            is responsible for starting it.
        """
        def _receive_once():
            # copy=False yields a frame object; ``.bytes`` exposes its payload.
            frame = self.sim2mgr_socket.recv(copy=False)
            # NOTE(review): if ``loads`` is pickle-based, only use this on
            # trusted peers.
            decoded = loads(frame.bytes)
            self.queue.put(decoded)

        receiver = LoopThread(_receive_once, pausable=False)
        receiver.name = "recv thread"
        return receiver
Exemple #2
0
 def get_simulator_thread(self):
     """Build the (unstarted) thread that runs the policy off the main thread.

     Each loop iteration waits for an item on ``self._populate_job_queue``.
     Then, with ``self.trainer.sess`` installed as the default session, it
     calls ``self._populate_exp()`` ``self.update_frequency`` times. The
     original author noted that running this in a separate thread gave about
     a 1.3x speed-up.

     Returns:
         LoopThread: a non-pausable thread named "SimulatorThread".
     """
     def _run_populate_job():
         # The queued item is only a wake-up signal; its value is ignored.
         self._populate_job_queue.get()
         session = self.trainer.sess
         with session.as_default():
             for _step in range(self.update_frequency):
                 self._populate_exp()

     simulator_thread = LoopThread(_run_populate_job, pausable=False)
     simulator_thread.name = "SimulatorThread"
     return simulator_thread
Exemple #3
0
    def get_recv_thread(self):
        """Build the (unstarted) thread that receives messages for this agent.

        Every loop iteration:
          1. blocks on ``self.sim2exp_socket`` for one frame and decodes its
             payload with ``loads``;
          2. prints a "received msg" line tagged with ``self.agent_name``;
          3. tries to put the message on ``self.queue`` without blocking.
             If that raises, the message is dropped and a line is logged.

        Returns:
            LoopThread: a non-pausable thread named "recv thread", not marked
            as a daemon thread.
        """
        def _receive_and_enqueue():
            raw = self.sim2exp_socket.recv(copy=False).bytes
            payload = loads(raw)
            print('{}: received msg'.format(self.agent_name))
            try:
                self.queue.put_nowait(payload)
            except Exception:
                # NOTE(review): the message is discarded here (e.g. when the
                # queue is full); nothing is retried.
                logger.info('put queue failed!')
            # TODO: decide whether a response should be sent back to the sender.

        receiver = LoopThread(_receive_and_enqueue, pausable=False)
        receiver.name = "recv thread"
        return receiver
Exemple #4
0
    def get_work_thread(self):
        """Build the (unstarted) worker thread that serves queued simulator requests.

        Each loop iteration takes one message from ``self.queue``. Fields used
        below:
          - ``msg[0]``: the simulator name, a str (it is ``.encode('utf-8')``-ed
            to address the reply on ``self.mgr2sim_socket``);
          - ``msg[1]``: a ``SimulatorManager.MSG_TYPE`` value;
          - ``msg[2]``: for CLICK only, an indexable pair passed to ``click``
            (presumably x, y coordinates; TODO confirm).

        A simple exclusive lock is implemented via ``self.locked_sim``:
          - LOCK while unlocked: grants the lock and replies 'lock'.
          - While locked, messages from other simulators are put back on the
            queue after a 0.2 s sleep.
          - UNLOCK from the holder releases the lock and replies 'unlock'.
        Otherwise it switches context to the sender. It then replies with a
        screen grab for SCREEN, or performs the click and replies 'click' for
        CLICK.

        NOTE(review): these cases fall through to ``cxt_switch`` and send no
        reply: LOCK from the current holder, UNLOCK while no lock is held,
        and any type other than SCREEN/CLICK. Confirm the simulator side does
        not block waiting for a reply in those cases.

        Returns:
            LoopThread: a non-pausable thread named "work thread".
        """
        def f():
            msg = self.queue.get()
            sim_name = msg[0]
            # Grant the lock only when nobody currently holds it.
            if msg[1] == SimulatorManager.MSG_TYPE.LOCK and self.locked_sim is None:
                self.locked_sim = sim_name
                self.mgr2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps('lock')])
                time.sleep(0.2)
                return
            if self.locked_sim is not None:
                if sim_name != self.locked_sim:
                    # Another simulator holds the lock: back off briefly and
                    # re-enqueue so this request is retried later (it goes back
                    # to the end of the queue, assuming a FIFO queue).
                    time.sleep(0.2)
                    self.queue.put(msg)
                    return
                elif msg[1] == SimulatorManager.MSG_TYPE.UNLOCK:
                    self.locked_sim = None
                    self.mgr2sim_socket.send_multipart(
                        [sim_name.encode('utf-8'),
                         dumps('unlock')])
                    time.sleep(0.2)
                    return

            # Presumably brings the sender's simulator into focus before
            # grabbing the screen / clicking (TODO confirm cxt_switch semantics).
            self.cxt_switch(sim_name)
            # time.sleep(0.2)
            # print(msg[1])
            if msg[1] == SimulatorManager.MSG_TYPE.SCREEN:
                screen = grab_screen()
                self.mgr2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps(screen)])
            elif msg[1] == SimulatorManager.MSG_TYPE.CLICK:
                # print('need to click')
                click(msg[2][0], msg[2][1])
                self.mgr2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps('click')])

        work_thread = LoopThread(f, pausable=False)
        work_thread.name = "work thread"
        return work_thread
Exemple #5
0
    def setup(self):
        """Run the base trainer setup, then create one paused thread per aux train op.

        For every aux train op in ``self._train_ops_aux`` (a mapping whose
        values are the op objects):
          - if the op has callbacks, ``op._sess`` becomes a ``HookedSession``
            around ``self.sess``, with each callback wrapped by
            ``CallbackToHook``; otherwise ``op._sess`` is ``self.sess`` itself;
          - a ``LoopThread`` that repeatedly runs ``op._train_op`` is created,
            paused, started, and appended to ``self._training_aux_threads``.

        The thread is paused before it is started. Resuming the threads is
        presumably done elsewhere once training begins (TODO confirm).
        """
        super(MultiGPUTrainer, self).setup()

        self._training_aux_threads = []
        self._training_aux_running = False
        self._training_aux_step_counter = itertools.count()
        from tensorpack.callbacks.hooks import CallbackToHook
        from tensorflow.python.training.monitored_session \
            import _HookedSession as HookedSession

        # Only the op objects are needed, so iterate values directly instead
        # of looking each key up again.
        for tidx, aux_op in enumerate(self._train_ops_aux.values()):
            logger.info("Create aux train ops {}".format(aux_op._name))
            if aux_op._callbacks:
                aux_op._sess = HookedSession(
                    self.sess,
                    hooks=[CallbackToHook(cb) for cb in aux_op._callbacks])
            else:
                aux_op._sess = self.sess

            def f(op=aux_op):  # default arg binds op now, avoiding late binding
                try:
                    op._sess.run([op._train_op])  # TODO this won't work with StageInput
                except (RuntimeError,
                        tf.errors.CancelledError,
                        tf.errors.AbortedError):
                    # RuntimeError per the original note means the session has
                    # exited; cancelled/aborted runs are also ignored and the
                    # loop simply continues.
                    pass
                # next(self._training_aux_step_counter)   # atomic due to GIL

            th = LoopThread(f)
            th.name = "AsyncLoopThread-{}".format(tidx)
            th.pause()
            th.start()
            logger.info("Start aux thread {}".format(aux_op._name))
            self._training_aux_threads.append(th)