Code Example #1
File: agent.py Project: zhly0/relaax
    def init(self, exploit=False):
        self.exploit = exploit
        model = da3c_model.AgentModel()
        self.session = session.Session(model)
        if da3c_config.config.use_lstm:
            self.lstm_state = self.initial_lstm_state = self.lstm_zero_state = model.lstm_zero_state

        self.observation = observation.Observation(
            da3c_config.config.input.history)
        self.last_action = None
        self.last_value = None
        self.last_probs = None
        if da3c_config.config.hogwild and not da3c_config.config.use_icm:
            self.queue = queue.Queue(10)
            threading.Thread(target=self.execute_tasks).start()
            self.receive_experience()
        else:
            self.queue = None
        if da3c_config.config.use_icm:
            self.icm_observation = observation.Observation(
                da3c_config.config.input.history)

        if da3c_config.config.use_filter:
            self.filter = utils.ZFilter(da3c_config.config.input.shape)
        self.replay_buffer = DA3CReplayBuffer(self)
        return True
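This example and Example #8 below both hand work to a background thread through a bounded queue.Queue(10), but the execute_tasks worker itself is not part of the excerpt. A minimal sketch of what such a consumer loop could look like (the class name, the (task, args) tuple format, and the loop body are assumptions, not relaax's actual code):

import queue
import threading

class HogwildAgent:
    def __init__(self):
        # Bounded queue: producers block once 10 tasks are pending,
        # throttling experience collection to match training speed.
        self.queue = queue.Queue(10)
        threading.Thread(target=self.execute_tasks, daemon=True).start()

    def execute_tasks(self):
        # Hypothetical consumer: drain (callable, args) tuples forever.
        while True:
            task, args = self.queue.get()
            task(*args)
            self.queue.task_done()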
Code Example #2
File: pg_batch.py Project: yhyu13/relaax
    def __init__(self, parameter_server, exploit):
        self.exploit = exploit
        self.ps = parameter_server
        self.session = session.Session(pg_model.PolicyModel())
        self.reset()
        self.last_state = None
        self.last_action = None
Code Example #3
File: agent.py Project: deeplearninc/relaax
    def init(self, exploit=False):
        model = trpo_model.AgentModel()
        self.session = session.Session(model)

        self._episode_timestep = 0  # timestep for current episode (round)
        self._episode_reward = 0  # score accumulator for current episode (round)
        self._stop_training = False  # flag to halt any further training

        self.data = defaultdict(list)
        self.observation = observation.Observation(
            trpo_config.config.input.history)

        self.policy = network.make_policy_wrapper(self.session,
                                                  self.ps.metrics)

        # counter for global updates at parameter server
        self._n_iter = self.ps.session.call_wait_for_iteration()
        self.session.op_set_weights(
            weights=self.ps.session.policy.op_get_weights())

        if trpo_config.config.use_filter:
            self.obs_filter = network.make_filter(trpo_config.config)
            state = self.ps.session.call_get_filter_state()
            self.obs_filter.rs.set(*state)

        self.server_latency_accumulator = 0  # accumulator for averaging server latency
        self.collecting_time = time.time()  # timer for collecting experience

        return True
Code Example #4
File: parameter_server.py Project: zhly0/relaax
    def init_session(self):
        self.session = session.Session(trpo_model.SharedParameters())
        # 'async' became a reserved word in Python 3.7, so read it via getattr
        if getattr(trpo_config.config, 'async'):
            self.session.ps = PsAsync(self.session, self.metrics, self)
        else:
            self.session.ps = Ps(self.session, self.metrics, self)
        self.session.op_initialize()
Code Example #5
    def init_session(self):
        sg_model = dppo_model.Model()
        policy_shared = dppo_model.SharedWeights(sg_model.actor.weights)
        value_func_shared = dppo_model.SharedWeights(sg_model.critic.weights)

        self.session = session.Session(policy=policy_shared,
                                       value_func=value_func_shared)

        self.session.policy.op_initialize()
        self.session.value_func.op_initialize()
Code Example #6
    def test_niter(self):
        n_iter = trpo_graph.NIter()
        init = tf.global_variables_initializer()
        s = session.Session(n_iter)
        s.session.run(init)

        assert s.op_n_iter() == 0
        s.op_next_iter()
        assert s.op_n_iter() == 1
        s.op_turn_collect_on()
        assert s.op_n_iter() == -1
        s.op_next_iter()
        assert s.op_n_iter() == -1
        s.op_turn_collect_off()
        assert s.op_n_iter() == 2
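Taken together, the assertions pin down the NIter contract: op_next_iter advances a shared iteration counter, and while collection is turned on op_n_iter reports -1, hiding increments that still happen underneath, until op_turn_collect_off makes the updated count (here 2) visible again.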
Code Example #7
    def __init__(self, parameter_server, metrics, exploit):
        self.ps = parameter_server
        self.metrics = metrics
        self._exploit = exploit

        self.session = session.Session(dqn_model.AgentModel())
        self.session.op_initialize()

        self.replay_buffer = dqn_utils.ReplayBuffer(
            dqn_config.config.replay_buffer_size, dqn_config.config.alpha)
        self.observation = observation.Observation(
            dqn_config.config.input.history)

        self.last_action = None
        self.local_step = 0
        self.last_target_weights_update = 0
Code Example #8
    def __init__(self, parameter_server, metrics, exploit, hogwild_update):
        self.exploit = exploit
        self.ps = parameter_server
        self.metrics = metrics
        model = ddpg_model.AgentModel()
        self.session = session.Session(model)

        self.episode = episode.ReplayBuffer(
            ['state', 'action', 'reward', 'terminal', 'next_state'],
            cfg.config.buffer_size,
            seed=cfg.config.exploration.rnd_seed)
        self.episode.begin()
        self.observation = observation.Observation(cfg.config.input.history)
        self.last_action = self.noise_epsilon = None
        self.episode_cnt = self.cur_loop_cnt = 0
        self.exploration_noise = utils.OUNoise(cfg.config.output.action_size,
                                               cfg.config.exploration.ou_mu,
                                               cfg.config.exploration.ou_theta,
                                               cfg.config.exploration.ou_sigma,
                                               cfg.config.exploration.rnd_seed)
        self.max_q = self.step_cnt = 0
        self.agent_weights_id = 0
        self.terminal = False

        if hogwild_update:
            self.queue = queue.Queue(10)
            threading.Thread(target=self.execute_tasks).start()
            self.receive_experience()
        else:
            self.queue = None

        if cfg.config.use_filter:
            shape = cfg.config.input.shape
            if shape == [0]:
                shape = [1]

            self.filter = utils.ZFilter(shape)

        if cfg.config.no_ps:
            self.session.op_initialize()
            self.session.op_init_target_weights()
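Example #8 drives exploration with utils.OUNoise. Its implementation is not shown in these excerpts; for reference, a minimal Ornstein-Uhlenbeck process of the kind DDPG agents commonly use could look like this (the constructor mirrors the call above, but the body is an assumption, not relaax's code):

import numpy as np

class OUNoise:
    # Mean-reverting noise: dx = theta * (mu - x) + sigma * N(0, 1)
    def __init__(self, action_size, mu, theta, sigma, seed=None):
        self.mu = mu * np.ones(action_size)
        self.theta = theta
        self.sigma = sigma
        self.rng = np.random.RandomState(seed)
        self.state = self.mu.copy()

    def noise(self):
        # One Euler step of the OU process (dt = 1).
        self.state += self.theta * (self.mu - self.state) \
            + self.sigma * self.rng.randn(len(self.state))
        return self.state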
Code Example #9
File: dppo_batch.py Project: deeplearninc/relaax
    def __init__(self, parameter_server, exploit, metrics):
        self.exploit = exploit
        self.metrics = metrics
        self.ps = parameter_server
        model = dppo_model.Model(assemble_model=True)
        self.session = session.Session(policy=model.policy,
                                       value_func=model.value_func)
        self.episode = None
        self.steps = 0
        self.observation = observation.Observation(
            dppo_config.config.input.history)

        if dppo_config.config.use_lstm:
            self.initial_lstm_state = self.lstm_state = self.lstm_zero_state = model.lstm_zero_state
            self.mini_batch_lstm_state = None
        self.terminal = False

        self.last_state = None
        self.last_action = None
        self.last_prob = None
        self.last_terminal = None

        self.final_state = None
        self.final_value = None

        self.policy_step = None
        self.value_step = None

        self.mini_batch_size = dppo_config.config.batch_size
        if dppo_config.config.mini_batch is not None:
            self.mini_batch_size = dppo_config.config.mini_batch

        if dppo_config.config.output.continuous:
            self.prob_type = DiagGauss(dppo_config.config.output.action_size)
        else:
            self.prob_type = Categorical(dppo_config.config.output.action_size)

        if dppo_config.config.use_filter:
            self.filter = ZFilter(dppo_config.config.input.shape)
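A ZFilter appears in Examples #1, #8, and #9 to normalize observations. Many RL codebases (modular_rl among them) implement it as a running mean/variance normalizer; a compact sketch under that assumption, not necessarily relaax's version:

import numpy as np

class ZFilter:
    # Scale inputs to zero mean / unit variance using running statistics.
    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # sum of squared deviations (Welford)

    def __call__(self, x):
        # Update the running statistics, then normalize the new observation.
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)
        std = np.sqrt(self.m2 / max(self.n - 1, 1)) + 1e-8
        return (x - self.mean) / std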
Code Example #10
    def __init__(self, parameter_server, exploit):
        self.exploit = exploit
        self.ps = parameter_server
        self.session = session.Session(lm=LocalManagerNetwork(),
                                       lw=LocalWorkerNetwork())
        self.reset()

        self.goal_buffer = RingBuffer2D(element_size=cfg.d,
                                        buffer_size=cfg.c * 2)
        self.st_buffer = RingBuffer2D(element_size=cfg.d,
                                      buffer_size=cfg.c * 2)
        self.states = []
        self.last_action = None
        self.last_value = None
        self.first = cfg.c  # =batch_size

        self.worker_start_lstm_state = None
        self.manager_start_lstm_state = None
        self.cur_c = 0

        # additional "last_*" fields
        self.last_zt_inp = None
        self.last_m_value = None
        self.last_goal = None
Code Example #11
    def init_session(self):
        self.session = session.Session(da3c_model.SharedParameters())
        self.session.op_initialize()
Code Example #12
    def init_session(self):
        self.session = session.Session(gm=GlobalManagerNetwork(),
                                       gw=GlobalWorkerNetwork())
        self.session.op_initialize()
Code Example #13
    def init_session(self):
        self.session = session.Session(ddpg_model.SharedParameters())
        self.session.op_initialize()
        self.session.op_init_target_weights()
Code Example #14
    def init(self, exploit=False):
        self.session = session.Session(pg_model.PolicyModel())
        self.trainer = Trainer(self.session, exploit, self.ps, self.metrics)
        return True
Code Example #15
File: parameter_server.py Project: zhly0/relaax
    def init_session(self):
        self.session = session.Session(dqn_model.GlobalServer())
        self.session.op_initialize()
Code Example #16
File: agent.py Project: zhly0/relaax
    def __init__(self, parameter_server, metrics):
        self.ps = parameter_server
        self.metrics = metrics
        self.session = session.Session(model.Model())
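Across all sixteen examples the lifecycle is the same: build a model object (a local agent model or shared parameters), wrap it in session.Session, then initialize state through op_* graph operations. A distilled template of that pattern (module names taken from the examples above; treat it as a summary sketch rather than runnable relaax code):

# Agent side: wrap a local model; weights are synced from the parameter server.
model = da3c_model.AgentModel()
agent_session = session.Session(model)

# Parameter-server side: wrap shared parameters and run one-time initializers.
ps_session = session.Session(da3c_model.SharedParameters())
ps_session.op_initialize()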