예제 #1
0
파일: agent.py 프로젝트: zhly0/relaax
    def init(self, exploit=False):
        """Build the DA3C agent's local state for a new run.

        Creates the agent model and its session, the observation history and
        the "last step" bookkeeping fields.

        When ``hogwild`` is enabled and ICM is not, a task queue (capacity 10)
        is created, a thread targeting ``self.execute_tasks`` is started, and
        ``self.receive_experience()`` is called once. Otherwise ``self.queue``
        is None.

        The LSTM state, ICM observation and input filter are only created
        when their config flags are set. The replay buffer is built last and
        is handed the agent itself.

        Args:
            exploit: stored as ``self.exploit``.

        Returns:
            True.
        """
        cfg = da3c_config.config
        self.exploit = exploit

        model = da3c_model.AgentModel()
        self.session = session.Session(model)
        if cfg.use_lstm:
            # Running, initial and zero LSTM states all start as the model's zero state.
            zero_state = model.lstm_zero_state
            self.lstm_state = zero_state
            self.initial_lstm_state = zero_state
            self.lstm_zero_state = zero_state

        self.observation = observation.Observation(cfg.input.history)
        self.last_action = self.last_value = self.last_probs = None

        hogwild = cfg.hogwild and not cfg.use_icm
        self.queue = queue.Queue(10) if hogwild else None
        if hogwild:
            threading.Thread(target=self.execute_tasks).start()
            self.receive_experience()

        if cfg.use_icm:
            self.icm_observation = observation.Observation(cfg.input.history)

        if cfg.use_filter:
            self.filter = utils.ZFilter(cfg.input.shape)

        # Created after every other field; it receives the agent instance itself.
        self.replay_buffer = DA3CReplayBuffer(self)
        return True
예제 #2
0
    def init(self, exploit=False):
        """Prepare the TRPO agent for a run.

        Steps, in order:
        - Build the local model and session.
        - Reset the per-episode counters.
        - Create the policy wrapper.
        - Fetch the global iteration counter and the current policy weights
          from the parameter server.
        - If filtering is enabled, restore the observation filter's state
          from the parameter server.

        Args:
            exploit: accepted for interface compatibility; not used here.

        Returns:
            True.
        """
        config = trpo_config.config
        self.session = session.Session(trpo_model.AgentModel())

        # Per-episode (round) bookkeeping.
        self._episode_timestep = 0      # timesteps taken in the current episode
        self._episode_reward = 0        # reward accumulated in the current episode
        self._stop_training = False     # when set, prevents any further training

        self.data = defaultdict(list)
        self.observation = observation.Observation(config.input.history)
        self.policy = network.make_policy_wrapper(self.session, self.ps.metrics)

        # Counter of global updates performed at the parameter server.
        self._n_iter = self.ps.session.call_wait_for_iteration()
        # Start from the parameter server's current policy weights.
        ps_weights = self.ps.session.policy.op_get_weights()
        self.session.op_set_weights(weights=ps_weights)

        if config.use_filter:
            self.obs_filter = network.make_filter(config)
            filter_state = self.ps.session.call_get_filter_state()
            self.obs_filter.rs.set(*filter_state)

        self.server_latency_accumulator = 0    # summed server latency, for averaging
        self.collecting_time = time.time()     # start of the experience-collection timer
        return True
예제 #3
0
    def __init__(self, parameter_server, metrics, exploit):
        """Create a DQN agent bound to a parameter server.

        Builds and immediately initialises the local model session, then
        creates the replay buffer and the observation history.

        Args:
            parameter_server: kept as ``self.ps``.
            metrics: kept as ``self.metrics``.
            exploit: kept as ``self._exploit``.
        """
        self.ps = parameter_server
        self.metrics = metrics
        self._exploit = exploit

        model = dqn_model.AgentModel()
        self.session = session.Session(model)
        self.session.op_initialize()

        settings = dqn_config.config
        self.replay_buffer = dqn_utils.ReplayBuffer(settings.replay_buffer_size,
                                                    settings.alpha)
        self.observation = observation.Observation(settings.input.history)

        # Step bookkeeping.
        self.last_action = None
        self.local_step = self.last_target_weights_update = 0
예제 #4
0
    def __init__(self, parameter_server, metrics, exploit, hogwild_update):
        """Create a DDPG agent.

        Sets up, in order:
        - the model session;
        - an episode replay buffer, with ``begin()`` called on it;
        - the observation history;
        - the OU exploration-noise generator;
        - the step counters.

        Args:
            parameter_server: kept as ``self.ps``.
            metrics: kept as ``self.metrics``.
            exploit: kept as ``self.exploit``.
            hogwild_update: if truthy, create a task queue (capacity 10),
                start a thread targeting ``self.execute_tasks`` and call
                ``self.receive_experience()`` once; otherwise
                ``self.queue`` is None.
        """
        conf = cfg.config
        explore = conf.exploration

        self.exploit = exploit
        self.ps = parameter_server
        self.metrics = metrics
        self.session = session.Session(ddpg_model.AgentModel())

        transition_fields = ['state', 'action', 'reward', 'terminal', 'next_state']
        self.episode = episode.ReplayBuffer(transition_fields,
                                            conf.buffer_size,
                                            seed=explore.rnd_seed)
        self.episode.begin()
        self.observation = observation.Observation(conf.input.history)

        self.last_action = None
        self.noise_epsilon = None
        self.episode_cnt = 0
        self.cur_loop_cnt = 0
        self.exploration_noise = utils.OUNoise(
            conf.output.action_size,
            explore.ou_mu,
            explore.ou_theta,
            explore.ou_sigma,
            explore.rnd_seed,
        )
        self.max_q = 0
        self.step_cnt = 0
        self.agent_weights_id = 0
        self.terminal = False

        self.queue = queue.Queue(10) if hogwild_update else None
        if hogwild_update:
            threading.Thread(target=self.execute_tasks).start()
            self.receive_experience()

        if conf.use_filter:
            filter_shape = conf.input.shape
            # An input shape of [0] is replaced by [1] before building the filter.
            self.filter = utils.ZFilter([1] if filter_shape == [0] else filter_shape)

        if conf.no_ps:
            # NOTE(review): presumably local initialisation for runs without a
            # parameter server -- confirm against the ps-backed code path.
            self.session.op_initialize()
            self.session.op_init_target_weights()
예제 #5
0
    def __init__(self, parameter_server, exploit, metrics):
        """Set up the DPPO agent's model, session and bookkeeping fields.

        Note the positional order: ``exploit`` comes before ``metrics``.

        Args:
            parameter_server: stored as ``self.ps``.
            exploit: stored as ``self.exploit``.
            metrics: stored as ``self.metrics``.
        """
        self.exploit = exploit
        self.metrics = metrics
        self.ps = parameter_server
        # The local session receives the model's policy and value-function parts separately.
        model = dppo_model.Model(assemble_model=True)
        self.session = session.Session(policy=model.policy,
                                       value_func=model.value_func)
        self.episode = None
        self.steps = 0
        self.observation = observation.Observation(
            dppo_config.config.input.history)

        if dppo_config.config.use_lstm:
            # Initial, running and zero LSTM states all start as the model's zero state.
            self.initial_lstm_state = self.lstm_state = self.lstm_zero_state = model.lstm_zero_state
            self.mini_batch_lstm_state = None
        self.terminal = False

        # Cleared to None here; presumably filled in by step-handling code
        # elsewhere in the class (not visible) -- confirm.
        self.last_state = None
        self.last_action = None
        self.last_prob = None
        self.last_terminal = None

        self.final_state = None
        self.final_value = None

        self.policy_step = None
        self.value_step = None

        # Mini-batch size defaults to the full batch size unless config.mini_batch is set.
        self.mini_batch_size = dppo_config.config.batch_size
        if dppo_config.config.mini_batch is not None:
            self.mini_batch_size = dppo_config.config.mini_batch

        # Action distribution type: DiagGauss for continuous outputs, Categorical otherwise.
        if dppo_config.config.output.continuous:
            self.prob_type = DiagGauss(dppo_config.config.output.action_size)
        else:
            self.prob_type = Categorical(dppo_config.config.output.action_size)

        if dppo_config.config.use_filter:
            self.filter = ZFilter(dppo_config.config.input.shape)