Example #1
    def __init__(self, pipe_c2s, pipe_s2c):
        super(SimulatorMaster, self).__init__()
        assert os.name != 'nt', "Doesn't support windows!"
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.bind(pipe_c2s)
        self.c2s_socket.set_hwm(20)
        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.bind(pipe_s2c)
        self.s2c_socket.set_hwm(20)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=1000)

        def f():
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)

        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure the sockets get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()

        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket],
                        self.context)
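
A hedged aside, not part of any example above: the master binds a PULL socket for client-to-server traffic and a ROUTER socket for replies, so a client would typically use the matching PUSH and DEALER sockets. The sketch below is an assumption built from those pairings and from the (ident, state, reward, isOver) tuple unpacked in Example #2; the identity value, the env object, and the dumps/loads serializers are hypothetical stand-ins.

# Hypothetical client-side counterpart (a sketch, not taken from the examples).
# PUSH pairs with the master's PULL, DEALER with its ROUTER; the identity must
# be bytes so the ROUTER can route replies back to this client.
import zmq

def run_client(pipe_c2s, pipe_s2c, identity, env, dumps, loads):
    context = zmq.Context()
    c2s_socket = context.socket(zmq.PUSH)
    c2s_socket.setsockopt(zmq.IDENTITY, identity)
    c2s_socket.connect(pipe_c2s)
    s2c_socket = context.socket(zmq.DEALER)
    s2c_socket.setsockopt(zmq.IDENTITY, identity)
    s2c_socket.connect(pipe_s2c)

    state = env.reset()
    reward, is_over = 0, False
    while True:
        # the master unpacks (ident, state, reward, isOver) from each message
        c2s_socket.send(dumps([identity, state, reward, is_over]), copy=False)
        action = loads(s2c_socket.recv(copy=False).bytes)
        state, reward, is_over = env.step(action)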
Example #2
class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce an action for each simulator, as well as
        define callbacks for when a transition or an episode is finished.
    """
    class ClientState(object):
        def __init__(self):
            self.memory = []    # list of Experience
            self.ident = None

    def __init__(self, pipe_c2s, pipe_s2c):
        super(SimulatorMaster, self).__init__()
        assert os.name != 'nt', "Doesn't support windows!"
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.bind(pipe_c2s)
        self.c2s_socket.set_hwm(10)
        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.bind(pipe_s2c)
        self.s2c_socket.set_hwm(10)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure the sockets get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                msg = loads(self.c2s_socket.recv(copy=False).bytes)
                ident, state, reward, isOver = msg
                client = self.clients[ident]
                if client.ident is None:
                    client.ident = ident
                # maybe check history and warn about dead client?
                self._process_msg(client, state, reward, isOver)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")

    def __del__(self):
        self.context.destroy(linger=0)
Example #3
    def get_recv_thread(self):
        def f():
            msg = loads(self.sim2mgr_socket.recv(copy=False).bytes)
            self.queue.put(msg)

        recv_thread = LoopThread(f, pausable=False)
        # recv_thread.daemon = True
        recv_thread.name = "recv thread"
        return recv_thread
Example #4
 def get_simulator_thread(self):
     # spawn a separate thread to run policy, can speed up 1.3x
     def populate_job_func():
         self._populate_job_queue.get()
         with self.trainer.sess.as_default():
             for _ in range(self.update_frequency):
                 self._populate_exp()
     th = LoopThread(populate_job_func, pausable=False)
     th.name = "SimulatorThread"
     return th
Example #5
 def _before_train(self):
     model = self.trainer.model
     from tensorpack.callbacks.hooks import CallbackToHook
     from tensorpack.train.base import ReuseSessionCreator
     if len(self._callbacks) > 0:
         self._sess = tf.train.MonitoredSession(
             session_creator=ReuseSessionCreator(self.trainer.sess),
             hooks=[CallbackToHook(cb) for cb in self._callbacks])
     else:
         self._sess = self.trainer.sess
     self.trainer.sess.run(self._op_sync_weights)
     self._thread = LoopThread(self._run_loop)
     self._thread.start()
Example #6
    def get_recv_thread(self):
        def f():
            msg = self.sim2exp_socket.recv(copy=False).bytes
            msg = loads(msg)
            print('{}: received msg'.format(self.agent_name))
            try:
                self.queue.put_nowait(msg)
            except Exception:
                logger.info('failed to put message into queue!')
            # send response or not?

        recv_thread = LoopThread(f, pausable=False)
        # recv_thread.daemon = True
        recv_thread.name = "recv thread"
        return recv_thread
Example #7
    def _before_train(self):
        for p in self.predictors:
            self.predictors[p].start()

        def f():
            msg = loads(self.sim2coord_socket.recv(copy=False).bytes)
            sim_name = msg[0]
            agent_name = msg[1]

            def cb(outputs):
                try:
                    output = outputs.result()
                except CancelledError:
                    logger.info("{} cancelled.".format(sim_name))
                    return
                print('coordinator sending', sim_name.encode('utf-8'),
                      output[0].shape)
                self.coord2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps(output[0])])

            self.predictors[agent_name].put_task(msg[2:], cb)

        self.recv_thread = ShareSessionThread(LoopThread(f, pausable=False))
        # self.recv_thread.daemon = True
        self.recv_thread.name = 'coordinator recv'
        self.recv_thread.start()
        logger.info('Coordinator started')
Example #8
    def __init__(self, pipe_c2s, pipe_s2c):
        super(SimulatorMaster, self).__init__()
        assert os.name != 'nt', "Doesn't support windows!"
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.bind(pipe_c2s)
        self.c2s_socket.set_hwm(10)
        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.bind(pipe_s2c)
        self.s2c_socket.set_hwm(10)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure the sockets get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
Example #9
    def _create_simulator_thread(self, idx):
        # spawn a separate thread to run policy
        def populate_job_func():
            exp = self._populate_job_queue.get()
            self._runners[idx].step(exp)

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread-{}".format(idx)
        return th
Example #10
 def get_simulator_thread(self):
     # spawn a separate thread to run policy
     def populate_job_func():
         self._populate_job_queue.get()
         for _ in range(self.update_frequency):
             self._populate_exp()
     th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
     th.name = "SimulatorThread"
     return th
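
A side-by-side note, not from the examples themselves: Example #4 makes the trainer's session the default inside the loop body by entering trainer.sess.as_default() explicitly, while Example #10 gets the same effect by wrapping the LoopThread in ShareSessionThread. The sketch below contrasts the two forms; it assumes the tensorpack.utils.concurrency import path and simplifies away the job queue, so populate_exp and update_frequency are placeholders.

# Two ways, both seen above, to run the loop body under the trainer's session.
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread

def explicit_session_thread(trainer, populate_exp, update_frequency):
    def body():
        with trainer.sess.as_default():      # enter the session by hand (Example #4)
            for _ in range(update_frequency):
                populate_exp()
    return LoopThread(body, pausable=False)

def shared_session_thread(populate_exp, update_frequency):
    def body():
        for _ in range(update_frequency):    # session provided by the wrapper (Example #10)
            populate_exp()
    # ShareSessionThread shares the caller's default TF session with the wrapped thread
    return ShareSessionThread(LoopThread(body, pausable=False))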
Example #11
    def get_work_thread(self):
        def f():
            msg = self.queue.get()
            sim_name = msg[0]
            if msg[1] == SimulatorManager.MSG_TYPE.LOCK and self.locked_sim is None:
                self.locked_sim = sim_name
                self.mgr2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps('lock')])
                time.sleep(0.2)
                return
            if self.locked_sim is not None:
                if sim_name != self.locked_sim:
                    time.sleep(0.2)
                    self.queue.put(msg)
                    return
                elif msg[1] == SimulatorManager.MSG_TYPE.UNLOCK:
                    self.locked_sim = None
                    self.mgr2sim_socket.send_multipart(
                        [sim_name.encode('utf-8'),
                         dumps('unlock')])
                    time.sleep(0.2)
                    return

            self.cxt_switch(sim_name)
            # time.sleep(0.2)
            # print(msg[1])
            if msg[1] == SimulatorManager.MSG_TYPE.SCREEN:
                screen = grab_screen()
                self.mgr2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps(screen)])
            elif msg[1] == SimulatorManager.MSG_TYPE.CLICK:
                # print('need to click')
                click(msg[2][0], msg[2][1])
                self.mgr2sim_socket.send_multipart(
                    [sim_name.encode('utf-8'),
                     dumps('click')])

        work_thread = LoopThread(f, pausable=False)
        work_thread.name = "work thread"
        return work_thread
Example #12
    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            ###############################################################################
            # HITL UPDATE
            # as self.update_frequency = 0 during pretraining, no workers will be initialized.
            ###############################################################################
            #logger.info("update_frequency: {}".format(self.update_frequency))

            for _ in range(int(self.update_frequency)):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th
Example #13
    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            i = 0
            # synchronous training
            while i < self.update_frequency:
                if self._populate_exp():
                    i += 1
                    time.sleep(0.1)

            # for _ in range(self.update_frequency):
            #     self._populate_exp()
        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th
Example #14
    def setup(self):
        super(MultiGPUTrainer, self).setup()

        self._training_aux_threads = []
        self._training_aux_running = False
        self._training_aux_step_counter = itertools.count()
        from tensorpack.callbacks.hooks import CallbackToHook
        from tensorflow.python.training.monitored_session \
            import _HookedSession as HookedSession

        for tidx, n in enumerate(self._train_ops_aux):
            auxTrainOp = self._train_ops_aux[n]
            logger.info("Create aux train ops {}".format(auxTrainOp._name))
            if len(auxTrainOp._callbacks) > 0:
                auxTrainOp._sess = HookedSession(
                    self.sess,
                    hooks=[CallbackToHook(cb) for cb in auxTrainOp._callbacks])
            else:
                auxTrainOp._sess = self.sess

            def f(op=auxTrainOp):  # avoid late-binding
                try:
                    op._sess.run([op._train_op])  # TODO this won't work with StageInput
                except RuntimeError:  # exited
                    pass
                except tf.errors.CancelledError:
                    pass
                except tf.errors.AbortedError:
                    pass
                # next(self._training_aux_step_counter)   # atomic due to GIL

            th = LoopThread(f)
            th.name = "AsyncLoopThread-{}".format(tidx)
            th.pause()
            th.start()
            logger.info("Start aux thread {}".format(auxTrainOp._name))
            self._training_aux_threads.append(th)
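
Example #14 calls th.pause() before th.start(), so the aux threads begin paused and must be resumed by whatever drives training. The sketch below is a hypothetical driver that relies only on the pause()/resume() methods exercised above; where exactly to toggle the threads is an assumption made for illustration.

# Hypothetical driver for the aux threads created in Example #14.
# run_main_step is a placeholder for the main training step.
def run_step_with_aux(run_main_step, aux_threads):
    for th in aux_threads:
        th.resume()          # let the aux train ops run alongside this step
    try:
        run_main_step()
    finally:
        for th in aux_threads:
            th.pause()       # quiesce aux updates between steps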
Example #15
class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce an action for each simulator, as well as
        define callbacks for when a transition or an episode is finished.
    """
    class ClientState(object):

        def __init__(self):
            self.memory = []    # list of Experience

    def __init__(self, pipe_c2s, pipe_s2c):
        super(SimulatorMaster, self).__init__()
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.bind(pipe_c2s)
        self.c2s_socket.set_hwm(10)
        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.bind(pipe_s2c)
        self.s2c_socket.set_hwm(10)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure the sockets get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                msg = loads(self.c2s_socket.recv(copy=False).bytes)
                ident, state, reward, isOver = msg
                # TODO check history and warn about dead client
                client = self.clients[ident]

                # check if reward&isOver is valid
                # in the first message, only state is valid
                if len(client.memory) > 0:
                    client.memory[-1].reward = reward
                    if isOver:
                        self._on_episode_over(ident)
                    else:
                        self._on_datapoint(ident)
                # feed state and return action
                self._on_state(state, ident)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")

    @abstractmethod
    def _on_state(self, state, ident):
        """response to state sent by ident. Preferrably an async call"""

    @abstractmethod
    def _on_episode_over(self, client):
        """ callback when the client just finished an episode.
            You may want to clear the client's memory in this callback.
        """

    def _on_datapoint(self, client):
        """ callback when the client just finished a transition
        """

    def __del__(self):
        self.context.destroy(linger=0)
Example #16
class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce an action for each simulator, as well as
        define callbacks for when a transition or an episode is finished.
    """
    class ClientState(object):
        def __init__(self):
            self.memory = []    # list of Experience

    def __init__(self, pipe_c2s, pipe_s2c):
        super(SimulatorMaster, self).__init__()
        assert os.name != 'nt', "Doesn't support windows!"
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.bind(pipe_c2s)
        self.c2s_socket.set_hwm(10)
        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.bind(pipe_s2c)
        self.s2c_socket.set_hwm(10)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure the sockets get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                msg = loads(self.c2s_socket.recv(copy=False).bytes)
                ident, state, reward, isOver = msg
                # TODO check history and warn about dead client
                client = self.clients[ident]

                # check if reward&isOver is valid
                # in the first message, only state is valid
                if len(client.memory) > 0:
                    client.memory[-1].reward = reward
                    if isOver:
                        self._on_episode_over(ident)
                    else:
                        self._on_datapoint(ident)
                # feed state and return action
                self._on_state(state, ident)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")

    @abstractmethod
    def _on_state(self, state, ident):
        """response to state sent by ident. Preferrably an async call"""

    @abstractmethod
    def _on_episode_over(self, client):
        """ callback when the client just finished an episode.
            You may want to clear the client's memory in this callback.
        """

    def _on_datapoint(self, client):
        """ callback when the client just finished a transition
        """

    def __del__(self):
        self.context.destroy(linger=0)
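
To make the base class above concrete, a subclass only needs to fill in the three callbacks. The sketch below is illustrative rather than taken from the source: predict_action and the dumps serializer are assumptions, and a real implementation would usually issue the prediction asynchronously, as the _on_state docstring suggests.

# A minimal, illustrative subclass of the SimulatorMaster above.
# predict_action and dumps are assumed; replies are routed back to the right
# client because send_queue items are sent as [ident, payload] multipart
# messages over the ROUTER socket.
class MySimulatorMaster(SimulatorMaster):
    def __init__(self, pipe_c2s, pipe_s2c, predict_action):
        super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
        self._predict_action = predict_action

    def _on_state(self, state, ident):
        action = self._predict_action(state)         # ideally an async call
        self.send_queue.put([ident, dumps(action)])

    def _on_episode_over(self, ident):
        self.clients[ident].memory = []               # drop the finished episode

    def _on_datapoint(self, ident):
        pass                                          # e.g. push an Experience into a replay buffer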
Example #17
class EvaluatorBase(Callback):
    def __init__(
            self,
            data_io,
            name,
            is_training=False,
            idx=0,
            batch_size=-1,
            sync_step=-1,  # sync only in every epoch
            **kwargs):
        if not name.startswith('evaluate/'): name = 'evaluate/' + name
        self._name = name
        self._idx = idx
        from ..dataflow.tensor_io import TensorIO_AgentPools
        assert (isinstance(data_io, TensorIO_AgentPools))
        self._data_io = data_io  # type: TensorIO_AgentPools
        self._is_training = is_training

        assert (batch_size > 0)
        logger.info(
            "Eval: {} create, batchSize={}, is_train={}, sync_step={}".format(
                name, batch_size, is_training, sync_step))
        self._batch_size = batch_size
        self._kwargs = kwargs
        self._tensor_io = None
        self._pool_name = name.replace('/', '_')
        self._sync_step = sync_step

    @property
    def name(self):
        return self._name

    @property
    def batch_size(self):
        return self._batch_size

    def _setup_graph(self):
        model = self.trainer.model

        weights_train = model.getWeightsTrain()
        self._op_sync_weights = tf.group(
            *[d.assign(s) for d, s in zip(self._weights, weights_train)])
        self._callbacks = []
        from tensorpack.callbacks.summary import MergeAllSummaries_RunWithOp, MovingAverageSummary
        c_vars = tf.get_collection(self._name + '-ema_op')
        if len(c_vars) > 0:
            self._callbacks.append(MovingAverageSummary(self._name +
                                                        '-ema_op'))
        self._callbacks.append(MergeAllSummaries_RunWithOp(0, self._name))
        for c in self._callbacks:
            c.setup_graph(self.trainer)

        self._tensor_io._setup_graph()

    def getTensorIO(self, input_desc, **kwargs):
        if not self._tensor_io:
            self.get_input_tensor(input_desc, **kwargs)
        return self._tensor_io

    def get_input_tensor(self, input_desc, **kwargs):
        if self._tensor_io: return self._tensor_io.getInputTensors()

        self._input_desc = input_desc
        kwargs = kwargs.copy()
        kwargs.update(self._kwargs)
        self._tensor_io = self._data_io.getTensorIO(
            self._pool_name + '/pred',
            input_desc,
            queue_size=0,
            is_training=self._is_training,
            allow_no_full_batch=True,
            **kwargs)
        # self._tensor_io = TensorIO_AgentPool(self._name, self._datasets, input_desc, self._batch_size, queue_size = 0, is_training = self._is_training, **kwargs)
        return self._tensor_io.getInputTensors()

    def set_output_tensor(self, *outputs):
        assert (self._tensor_io is not None)
        self._tensor_io.setOutputTensors(*outputs)

    def set_weights(self, weights):
        self._weights = weights

    def _before_train(self):
        model = self.trainer.model
        from tensorpack.callbacks.hooks import CallbackToHook
        from tensorpack.train.base import ReuseSessionCreator
        if len(self._callbacks) > 0:
            self._sess = tf.train.MonitoredSession(
                session_creator=ReuseSessionCreator(self.trainer.sess),
                hooks=[CallbackToHook(cb) for cb in self._callbacks])
        else:
            self._sess = self.trainer.sess
        self.trainer.sess.run(self._op_sync_weights)
        self._thread = LoopThread(self._run_loop)
        self._thread.start()
        # self._tensor_io._before_train()

    def _trigger(self):
        self.trainer.sess.run(self._op_sync_weights)

    def _before_run(self, ctx):
        if self._sync_step > 0 and self.local_step % self._sync_step == 0:
            return [self._op_sync_weights]
        return None

    def _after_train(self):
        if self._tensor_io:
            self._tensor_io.close()

    def _run_loop(self):
        if self._sess is None: return
        logger.info(
            "Evaluator {} thread start, fetch tensors = {}, batch = {}".format(
                self._name, len(self._tensor_io._output_tensors),
                self._batch_size))
        hooked_sess = self.trainer.hooked_sess
        sess = self._sess
        try:
            tensor_io = self._tensor_io
            while not hooked_sess.should_stop():
                tensor_io._loopStep(sess)
                # logger.info("evaluator loop")
        except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
            pass
        except Exception as e:
            logger.exception("Exception in Evaluator Thread: {}".format(e))
        finally:
            try:
                self._tensor_io.close()
            except Exception:
                pass
            logger.info("Evaluator {} Thread Exited.".format(self._name))
            self._sess = None