Beispiel #1
0
    def test_pull_hyperparam(self):
        """Pulling the 'hyperparam' attr returns the latest value pushed per key.

        Two fresh keys are pushed through a client connected to two local
        model-pool endpoints; the second key is pushed twice so that the
        pulled value must be the most recent push for that key.
        """
        pool_addrs = ["localhost:11001:11006", "localhost:11002:11007"]
        client = ModelPoolAPIs(model_pool_addrs=pool_addrs)

        key1 = str(uuid.uuid1())
        key2 = str(uuid.uuid1())

        # (hyperparam, key) pairs in push order; key2 is overwritten last.
        pushes = (
            ("any_hyperparam_object", key1),
            ("any_hyperparam_object", key2),
            ("updated_hyperparam_object", key2),
        )
        for hyperparam, key in pushes:
            client.push_model(None, hyperparam, key)

        # key1 was pushed once, so its original value must still be there.
        self.assertEqual(client.pull_attr('hyperparam', key1),
                         "any_hyperparam_object")
        # key2 must reflect the second push.
        self.assertEqual(client.pull_attr('hyperparam', key2),
                         "updated_hyperparam_object")
Beispiel #2
0
class InfServer(object):
    """Serve batched policy inference on top of an ``InferDataServer``.

    The server builds the policy net on the batched input exposed by the data
    server, repeatedly runs the net, hands every result back to the data
    server, and periodically refreshes the net parameters from the model pool.
    """

    def __init__(self,
                 league_mgr_addr,
                 model_pool_addrs,
                 port,
                 ds,
                 batch_size,
                 ob_space,
                 ac_space,
                 policy,
                 outputs=None,
                 policy_config=None,
                 gpu_id=0,
                 compress=True,
                 batch_worker_num=4,
                 update_model_seconds=60,
                 learner_id=None,
                 log_seconds=60,
                 model_key="",
                 task_attr='model_key',
                 **kwargs):
        """Build the inference graph and load the initial model.

        Args:
            league_mgr_addr: address passed to ``LeagueMgrAPIs``; only used
                when ``model_key`` is empty.
            model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
            port: port passed to ``InferDataServer``.
            ds: data description passed to ``InferDataServer``.
            batch_size: number of samples in one inference batch.
            ob_space: observation space passed to the net config.
            ac_space: action space passed to the net config; it also defines
                the action structure used to split the batched outputs.
            policy: object exposing ``net_config_cls`` and ``net_build_fun``.
            outputs: names of the outputs to fetch (keys of
                ``self.all_outputs``). Defaults to ``['a']``. The sequence is
                copied, so the caller's object is never modified.
            policy_config: keyword arguments for ``policy.net_config_cls``.
                Defaults to an empty config. The mapping is copied before
                ``batch_size`` is written into it.
            gpu_id: GPU to make visible; a negative value disables GPU usage.
            compress: forwarded to ``InferDataServer``.
            batch_worker_num: forwarded to ``InferDataServer``.
            update_model_seconds: minimum number of seconds between two model
                refresh attempts in ``run``.
            learner_id: learner whose task is queried from the league manager.
            log_seconds: minimum number of seconds between two throughput
                log lines in ``run``.
            model_key: key of a fixed model to serve. When empty, the key is
                obtained by querying the league manager.
            task_attr: dot-separated attribute path resolved on the queried
                task to obtain the model key.
            **kwargs: accepted and ignored.
        """
        self._update_model_seconds = update_model_seconds
        self._log_seconds = log_seconds
        self._learner_id = learner_id
        self._task_attr = task_attr.split('.')
        if model_key:
            # If model_key is given, this indicates the infserver works
            # for a fixed model inference
            self._league_mgr_apis = None
            self.is_rl = False
            self.model_key = model_key
        else:
            # If model_key is absent, this indicates an infserver
            # that performs varying policy inference, and model_key will be
            # assigned by querying league_mgr
            self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
            self.is_rl = True
            self.model_key = None
        self.model = None
        self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
        assert hasattr(policy, 'net_config_cls')
        assert hasattr(policy, 'net_build_fun')
        # bookkeeping
        self.ob_space = ob_space
        self.ac_space = ac_space
        self.batch_size = batch_size
        self._ac_structure = tp_utils.template_structure_from_gym_space(
            ac_space)
        # Private copy: setup_fetches() may append 'state' to this list, which
        # must not leak into the caller's list or into other instances.
        self.outputs = ['a'] if outputs is None else list(outputs)
        # build the net
        # Private copy for the same reason: 'batch_size' is written below.
        policy_config = {} if policy_config is None else dict(policy_config)
        policy_config['batch_size'] = batch_size
        use_gpu = (gpu_id >= 0)
        self.data_server = InferDataServer(
            port=port,
            batch_size=batch_size,
            ds=ds,
            batch_worker_num=batch_worker_num,
            use_gpu=use_gpu,
            compress=compress,
        )
        config = tf.ConfigProto(allow_soft_placement=True)
        if use_gpu:
            config.gpu_options.visible_device_list = str(gpu_id)
            config.gpu_options.allow_growth = True
            if policy_config.get('use_xla'):
                config.graph_options.optimizer_options.global_jit_level = (
                    tf.OptimizerOptions.ON_1)
        self._sess = tf.Session(config=config)
        self.nc = policy.net_config_cls(ob_space, ac_space, **policy_config)
        self.net_out = policy.net_build_fun(self.data_server._batch_input,
                                            self.nc,
                                            scope='Inf_server')
        # saving/loading ops
        self.params = self.net_out.vars.all_vars
        self.params_ph = [
            tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.params
        ]
        self.params_assign_ops = [
            p.assign(np_p) for p, np_p in zip(self.params, self.params_ph)
        ]
        # initialize the net params
        tf.global_variables_initializer().run(session=self._sess)
        self.setup_fetches(self.outputs)
        self.id_and_fetches = [self.data_server._batch_data_id, self.fetches]
        self._update_model()

    def load_model(self, loaded_params):
        """Assign ``loaded_params`` to the leading net parameters.

        Only the first ``len(loaded_params)`` assign ops are run, so a shorter
        parameter list updates just the leading variables.
        """
        self._sess.run(
            self.params_assign_ops[:len(loaded_params)],
            feed_dict={p: v
                       for p, v in zip(self.params_ph, loaded_params)})

    def setup_fetches(self, outputs):
        """Build ``self.all_outputs`` and the per-sample ``self.fetches``.

        Args:
            outputs: list of output names to fetch. When the net config uses
                an LSTM and 'state' is not listed, 'state' is appended to
                this list in place.
        """
        def split_batch(template, tf_structure):
            # Split every tensor of the structure along the batch and
            # re-pack one structure per sample.
            split_flatten = zip(*[
                tf.split(t, self.batch_size)
                for t in nest.flatten_up_to(template, tf_structure)
            ])
            return [
                nest.pack_sequence_as(template, flatten)
                for flatten in split_flatten
            ]

        def head_field(heads, field):
            # Pick attribute `field` from every head, keeping the action
            # structure.
            return nest.map_structure_up_to(
                self._ac_structure, lambda head: getattr(head, field), heads)

        def split_or_empty(tensor):
            # A missing tensor yields one empty placeholder per sample.
            if tensor is None:
                return [[]] * self.batch_size
            return tf.split(tensor, self.batch_size)

        if self.nc.use_self_fed_heads:
            heads = self.net_out.self_fed_heads
            a = head_field(heads, 'sam')
            neglogp = head_field(heads, 'neglogp')
            flatparam = head_field(heads, 'flatparam')
            self.all_outputs = {
                'a': split_batch(self._ac_structure, a),
                'neglogp': split_batch(self._ac_structure, neglogp),
                'flatparam': split_batch(self._ac_structure, flatparam),
                'v': split_or_empty(self.net_out.value_head),
                'state': split_or_empty(self.net_out.S),
            }
        else:
            flatparam = head_field(self.net_out.outer_fed_heads, 'flatparam')
            self.all_outputs = {
                'flatparam': split_batch(self._ac_structure, flatparam),
                'state': split_or_empty(self.net_out.S),
            }
        if self.nc.use_lstm and 'state' not in outputs:
            outputs.append('state')
        # One dict per sample, mapping each requested output name to that
        # sample's slice.
        self.fetches = [
            dict(zip(outputs, pred))
            for pred in zip(*[self.all_outputs[o] for o in outputs])
        ]

    def _update_model(self):
        """Refresh the model key (RL mode) and reload parameters if needed."""
        if self.is_rl:
            # The task is re-queried on every call, regardless of whether the
            # current model is frozen.
            self._query_task()
        if self._should_update_model(self.model, self.model_key):
            self.model = self._model_pool_apis.pull_model(self.model_key)
            self.load_model(self.model.model)

    def _query_task(self):
        """Block until the learner task is available and derive the model key.

        The previous key is kept in ``self.last_model_key``; the new key is
        found by following ``self._task_attr`` on the task.

        Returns:
            The task returned by the league manager.
        """
        assert self.is_rl, '_query_task can be use in RL!'
        task = self._league_mgr_apis.query_learner_task(self._learner_id)
        while task is None:
            print('Learner has not request task! wait...')
            time.sleep(5)
            task = self._league_mgr_apis.query_learner_task(self._learner_id)
        self.last_model_key = self.model_key
        self.model_key = task
        for attr in self._task_attr:
            self.model_key = getattr(self.model_key, attr)
        return task

    def _should_update_model(self, model, model_key):
        """Tell whether the model stored under ``model_key`` must be pulled.

        Returns True when nothing is loaded yet or the key changed, False when
        the loaded model is frozen, and otherwise whether the model pool
        reports a newer update time than the loaded model's.
        """
        if model is None or model_key != model.key:
            return True
        if model.is_freezed():
            return False
        return self._model_pool_apis.pull_attr(
            'updatetime', model_key) > model.updatetime

    def run(self):
        """Serve inference forever.

        Waits until the data server reports ready, then loops: run one batch,
        send the responses, and, at the configured intervals, refresh the
        model and print throughput.
        """
        while not self.data_server.ready:
            time.sleep(10)
            print('Waiting at least {} actors to '
                  'connect ...'.format(self.batch_size),
                  flush=True)
        last_update_time = time.time()
        last_log_time = last_update_time
        batch_num = 0
        last_log_batch_num = 0
        pid = os.getpid()
        while True:
            # input is pre-fetched in self.data_server
            data_ids, outputs = self._sess.run(self.id_and_fetches, {})
            self.data_server.response(data_ids, outputs)
            batch_num += 1
            t0 = time.time()
            if t0 > last_update_time + self._update_model_seconds:
                self._update_model()
                last_update_time = t0
            # Re-read the clock: the model update above may have taken a while.
            t0 = time.time()
            if t0 > last_log_time + self._log_seconds:
                cost = t0 - last_log_time
                sam_num = self.batch_size * (batch_num - last_log_batch_num)
                print(
                    'Process {} predicts {} samples costs {} seconds, fps {}'.
                    format(pid, sam_num, cost, sam_num / cost),
                    flush=True)
                last_log_batch_num = batch_num
                last_log_time = t0
Beispiel #3
0
class ReplayActor(object):
    """Actor that turns replay frames into training data for a learner.

    Replay tasks are requested from the learner, each replay is converted
    frame by frame, and the compressed frames are queued for a background
    thread that pushes them to the learner in chunks of ``unroll_length``.
    """

    def __init__(self,
                 learner_addr,
                 replay_dir,
                 replay_converter_type,
                 policy=None,
                 policy_config=None,
                 model_pool_addrs=None,
                 n_v=1,
                 log_interval=50,
                 step_mul=8,
                 SC2_bin_root='/root/',
                 game_version='3.16.1',
                 unroll_length=32,
                 update_model_freq=32,
                 converter_config=None,
                 agent_cls=None,
                 infserver_addr=None,
                 compress=True,
                 da_rate=-1.,
                 unk_mmr_dft_to=4000):
        """Set up the converter, the optional agent and the push thread.

        Args:
            learner_addr: address passed to ``ImLearnerAPIs``.
            replay_dir: directory passed to ``ReplayExtractor``.
            replay_converter_type: callable building the replay converter
                from ``converter_config``.
            policy: optional policy; when given, an agent is created and its
                state is attached to every frame.
            policy_config: config for the policy net. It is copied before
                ``batch_size``, ``rollout_len`` and ``use_loss_type`` are
                overridden.
            model_pool_addrs: addresses passed to ``ModelPoolAPIs``; only used
                with a policy and without an inference server.
            n_v: forwarded to the agent class.
            log_interval: number of frames between two progress log lines.
            step_mul: forwarded to ``ReplayExtractor``.
            SC2_bin_root: root directory joined with the game version to form
                ``SC2PATH`` on Linux.
            game_version: default game version for tasks that specify none.
            unroll_length: number of frames per pushed chunk; also the
                capacity of the internal queue.
            update_model_freq: number of frames between two model refresh
                attempts.
            converter_config: keyword arguments for ``replay_converter_type``.
                It is copied before ``game_version`` is written into it.
            agent_cls: agent class for local inference; defaults to
                ``PPOAgent``.
            infserver_addr: when given, inference goes through ``PGAgentGPU``
                and local model updates are skipped.
            compress: forwarded to ``PGAgentGPU``.
            da_rate: forwarded to ``ReplayExtractor``.
            unk_mmr_dft_to: forwarded to ``ReplayExtractor``.
        """
        self._data_pool_apis = ImLearnerAPIs(learner_addr)
        self._SC2_bin_root = SC2_bin_root
        self._log_interval = log_interval
        self._replay_dir = replay_dir
        self._step_mul = step_mul
        self._game_version = game_version
        self._unroll_length = unroll_length
        # State read by the push thread; defined up front so it always exists.
        self._steps = 0
        self.replay_task = None
        self._data_queue = Queue(unroll_length)
        self._push_thread = Thread(target=self._push_data,
                                   args=(self._data_queue, ))
        self._push_thread.daemon = True
        # Private copy: 'game_version' is written here and again in run().
        self.converter_config = ({} if converter_config is None else
                                 dict(converter_config))
        self.converter_config['game_version'] = game_version
        self.replay_converter_type = replay_converter_type
        self._replay_converter = replay_converter_type(**self.converter_config)
        self._use_policy = policy is not None
        self._update_model_freq = update_model_freq
        self.model_key = 'IL-model'
        self._da_rate = da_rate
        self._unk_mmr_dft_to = unk_mmr_dft_to
        self._system = platform.system()
        # Always defined, so _update_agent_model() can read it safely.
        self.infserver_addr = infserver_addr
        ob_space, ac_space = self._replay_converter.space
        if self._use_policy:
            self.model = None
            # Private copy: the keys below are overridden for this actor.
            policy_config = ({} if policy_config is None else
                             dict(policy_config))
            agent_cls = agent_cls or PPOAgent
            policy_config['batch_size'] = 1
            policy_config['rollout_len'] = 1
            policy_config['use_loss_type'] = 'none'
            if infserver_addr is None:
                self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
                self.agent = agent_cls(policy,
                                       ob_space,
                                       ac_space,
                                       n_v=n_v,
                                       scope_name='self',
                                       policy_config=policy_config)
            else:
                nc = policy.net_config_cls(ob_space, ac_space, **policy_config)
                ds = InfData(ob_space, ac_space,
                             policy_config['use_self_fed_heads'], nc.use_lstm,
                             nc.hs_len)
                self.agent = PGAgentGPU(infserver_addr, ds, nc.hs_len,
                                        compress)
        self.ds = ILData(ob_space, ac_space, self._use_policy,
                         1)  # hs_len does not matter
        # Start the background pusher last, once construction has succeeded.
        self._push_thread.start()

    def run(self):
        """Process replay tasks until the learner returns an empty task."""
        self.replay_task = self._data_pool_apis.request_replay_task()
        while self.replay_task != "":
            game_version = self.replay_task.game_version or self._game_version
            self._adapt_system(game_version)
            if game_version != self._game_version:
                # need re-init replay converter
                self._game_version = game_version
                self.converter_config['game_version'] = game_version
                self._replay_converter = self.replay_converter_type(
                    **self.converter_config)
            game_core_config = self.converter_config.get(
                'game_core_config', {})
            extractor = ReplayExtractor(
                replay_dir=self._replay_dir,
                replay_filename=self.replay_task.replay_name,
                player_id=self.replay_task.player_id,
                replay_converter=self._replay_converter,
                step_mul=self._step_mul,
                version=game_version,
                game_core_config=game_core_config,
                da_rate=self._da_rate,
                unk_mmr_dft_to=self._unk_mmr_dft_to)
            self._steps = 0
            first_frame = True
            if self._use_policy:
                self.agent.reset()
                self._update_agent_model()
            for frame in extractor.extract():
                if self._use_policy:
                    # Append the agent state and a "first frame" flag.
                    # np.bool_ replaces the np.bool alias, which was removed
                    # from NumPy.
                    data = (*frame[0], self.agent.state,
                            np.array(first_frame, np.bool_))
                    # frame[0][0] is presumably the observation -- TODO confirm
                    self.agent.update_state(frame[0][0])
                    first_frame = False
                else:
                    data = frame[0]
                data = self.ds.flatten(self.ds.structure(data))
                if self._data_queue.full():
                    logger.log("Actor's queue is full.", level=logger.WARN)
                # frame[1] is consumed as the weight in _push_data().
                self._data_queue.put((TensorZipper.compress(data), frame[1]))
                logger.log('successfully put one tuple.', level=logger.DEBUG)
                self._steps += 1
                if self._steps % self._log_interval == 0:
                    logger.log(
                        "%d frames of replay task [%s] sent to learner." %
                        (self._steps, self.replay_task))
                if self._use_policy and self._steps % self._update_model_freq == 0:
                    self._update_agent_model()
            logger.log("Replay task [%s] done. %d frames sent to learner." %
                       (self.replay_task, self._steps))
            self.replay_task = self._data_pool_apis.request_replay_task()
        logger.log("All tasks done.")

    def _adapt_system(self, game_version):
        """Point ``SC2PATH`` at the binaries of ``game_version`` on Linux.

        For version '4.7.1' the variable is only overwritten when it is
        already present in the environment.
        """
        # TODO(pengsun): any stuff for Darwin, Window?
        if self._system == 'Linux':
            # set the SC2PATH for sc2 binary. See deepmind/pysc2 doc.
            if game_version != '4.7.1' or 'SC2PATH' in os.environ:
                os.environ['SC2PATH'] = os.path.join(self._SC2_bin_root,
                                                     game_version)

    def _update_agent_model(self):
        """Reload the local agent's model when the pool holds a newer one.

        Does nothing when an inference server is used.
        """
        if self.infserver_addr is not None:
            return
        logger.log('entering _update_agents_model',
                   'steps: {}'.format(self._steps),
                   level=logger.DEBUG + 5)
        if self._should_update_model(self.model, self.model_key):
            model = self._model_pool_apis.pull_model(self.model_key)
            self.agent.load_model(model.model)
            self.model = model

    def _should_update_model(self, model, model_key):
        """Return True if no model is loaded or the pool has a newer one."""
        if model is None:
            return True
        return self._model_pool_apis.pull_attr(
            'updatetime', model_key) > model.updatetime

    def _push_data(self, data_queue):
        """Push queued frames to the learner in chunks. Invoked in a thread.

        Waits until the queue holds data, then forever collects
        ``unroll_length`` (frame, weight) pairs and pushes them together with
        the replay task that was current when the chunk was started.
        """
        while data_queue.empty():
            time.sleep(5)
        logger.log('entering _push_data_to_learner',
                   'steps: {}'.format(self._steps),
                   level=logger.DEBUG + 5)
        while True:
            task = self.replay_task
            frames = []
            weights = []
            for _ in range(self._unroll_length):
                frame, weight = data_queue.get()
                frames.append(frame)
                weights.append(weight)
            self._data_pool_apis.push_data((task, frames, weights))