def init(self, exploit=False):
    self.exploit = exploit
    model = da3c_model.AgentModel()
    self.session = session.Session(model)

    if da3c_config.config.use_lstm:
        self.lstm_state = self.initial_lstm_state = self.lstm_zero_state = model.lstm_zero_state

    self.observation = observation.Observation(da3c_config.config.input.history)
    self.last_action = None
    self.last_value = None
    self.last_probs = None

    if da3c_config.config.hogwild and not da3c_config.config.use_icm:
        # hogwild mode: updates are queued and applied by a background thread
        self.queue = queue.Queue(10)
        threading.Thread(target=self.execute_tasks).start()
        self.receive_experience()
    else:
        self.queue = None

    if da3c_config.config.use_icm:
        # separate observation history for the Intrinsic Curiosity Module
        self.icm_observation = observation.Observation(da3c_config.config.input.history)

    if da3c_config.config.use_filter:
        self.filter = utils.ZFilter(da3c_config.config.input.shape)

    self.replay_buffer = DA3CReplayBuffer(self)
    return True
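A minimal sketch, independent of the framework code above, of the hogwild-style pattern the DA3C and DDPG constructors set up: a bounded queue.Queue(10) drained by a background thread running execute_tasks, so update tasks can be submitted without blocking the acting loop. The names HogwildWorker, submit, and _apply-style tasks below are illustrative assumptions, not part of the library.

import queue
import threading


class HogwildWorker:
    """Illustrative background worker: drains a bounded task queue on its own thread."""

    def __init__(self, maxsize=10):
        self.queue = queue.Queue(maxsize)            # bounded, like queue.Queue(10) above
        threading.Thread(target=self.execute_tasks, daemon=True).start()

    def execute_tasks(self):
        while True:
            task = self.queue.get()                  # blocks until work is available
            if task is None:                         # sentinel: shut the worker down
                break
            task()                                   # e.g. send gradients / receive weights

    def submit(self, fn):
        self.queue.put(fn)                           # blocks only when the queue is full

    def stop(self):
        self.queue.put(None)


# usage: queue a fake "apply gradients" task without blocking the caller
worker = HogwildWorker()
worker.submit(lambda: print('applying gradients'))
worker.stop()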
def init(self, exploit=False):
    model = trpo_model.AgentModel()
    self.session = session.Session(model)

    self._episode_timestep = 0      # timestep within the current episode (round)
    self._episode_reward = 0        # score accumulator for the current episode (round)
    self._stop_training = False     # flag that prevents any further training

    self.data = defaultdict(list)
    self.observation = observation.Observation(trpo_config.config.input.history)
    self.policy = network.make_policy_wrapper(self.session, self.ps.metrics)

    # counter of global updates at the parameter server
    self._n_iter = self.ps.session.call_wait_for_iteration()
    self.session.op_set_weights(weights=self.ps.session.policy.op_get_weights())

    if trpo_config.config.use_filter:
        self.obs_filter = network.make_filter(trpo_config.config)
        state = self.ps.session.call_get_filter_state()
        self.obs_filter.rs.set(*state)

    self.server_latency_accumulator = 0     # accumulator for averaging server latency
    self.collecting_time = time.time()      # timer for collecting experience
    return True
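The use_filter branches above restore a running-statistics observation filter from the parameter server via obs_filter.rs.set(*state). Below is a generic sketch of what such a filter typically tracks (count, mean, sum of squared deviations, updated with Welford's algorithm) so that its state can be shipped between processes; it is an assumption-level illustration, not the library's ZFilter or make_filter implementation.

import numpy as np


class RunningStat:
    """Welford-style running mean/variance whose state can be exported and restored."""

    def __init__(self, shape):
        self._n = 0
        self._mean = np.zeros(shape)
        self._m2 = np.zeros(shape)                   # sum of squared deviations from the mean

    def push(self, x):
        x = np.asarray(x)
        self._n += 1
        delta = x - self._mean
        self._mean += delta / self._n
        self._m2 += delta * (x - self._mean)

    @property
    def mean(self):
        return self._mean

    @property
    def std(self):
        var = self._m2 / self._n if self._n > 1 else np.ones_like(self._m2)
        return np.sqrt(np.maximum(var, 1e-8))

    def get(self):
        return self._n, self._mean.copy(), self._m2.copy()

    def set(self, n, mean, m2):                      # mirrors the rs.set(*state) call above
        self._n, self._mean, self._m2 = n, np.asarray(mean), np.asarray(m2)


# usage: restore state received from elsewhere, then normalize an observation
rs = RunningStat(shape=(4,))
for _ in range(100):
    rs.push(np.random.randn(4) * 3.0 + 1.0)
state = rs.get()                                     # what a parameter server could hand out
restored = RunningStat(shape=(4,))
restored.set(*state)
normalized = (np.random.randn(4) - restored.mean) / restored.std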
def __init__(self, parameter_server, metrics, exploit):
    self.ps = parameter_server
    self.metrics = metrics
    self._exploit = exploit

    self.session = session.Session(dqn_model.AgentModel())
    self.session.op_initialize()

    self.replay_buffer = dqn_utils.ReplayBuffer(dqn_config.config.replay_buffer_size,
                                                dqn_config.config.alpha)
    self.observation = observation.Observation(dqn_config.config.input.history)

    self.last_action = None
    self.local_step = 0
    self.last_target_weights_update = 0
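dqn_utils.ReplayBuffer is constructed with a capacity and an alpha parameter, which suggests prioritized sampling; without reproducing that, the sketch below shows only the simplest structure with the same append-and-sample role: a fixed-capacity, uniformly sampled ring buffer. The class name UniformReplayBuffer and its methods are hypothetical.

import random
from collections import deque


class UniformReplayBuffer:
    """Fixed-capacity experience buffer with uniform random sampling (no priorities)."""

    def __init__(self, capacity, seed=None):
        self.buffer = deque(maxlen=capacity)         # oldest transitions are evicted automatically
        self.rng = random.Random(seed)

    def append(self, state, action, reward, next_state, terminal):
        self.buffer.append((state, action, reward, next_state, terminal))

    def sample(self, batch_size):
        batch = self.rng.sample(self.buffer, min(batch_size, len(self.buffer)))
        return list(zip(*batch))                     # columns: states, actions, rewards, ...

    def __len__(self):
        return len(self.buffer)


# usage
buf = UniformReplayBuffer(capacity=1000, seed=0)
for t in range(50):
    buf.append(state=t, action=0, reward=1.0, next_state=t + 1, terminal=False)
states, actions, rewards, next_states, terminals = buf.sample(8)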
def __init__(self, parameter_server, metrics, exploit, hogwild_update):
    self.exploit = exploit
    self.ps = parameter_server
    self.metrics = metrics

    model = ddpg_model.AgentModel()
    self.session = session.Session(model)

    self.episode = episode.ReplayBuffer(
        ['state', 'action', 'reward', 'terminal', 'next_state'],
        cfg.config.buffer_size,
        seed=cfg.config.exploration.rnd_seed)
    self.episode.begin()

    self.observation = observation.Observation(cfg.config.input.history)
    self.last_action = self.noise_epsilon = None
    self.episode_cnt = self.cur_loop_cnt = 0

    # Ornstein-Uhlenbeck process for temporally correlated exploration noise
    self.exploration_noise = utils.OUNoise(cfg.config.output.action_size,
                                           cfg.config.exploration.ou_mu,
                                           cfg.config.exploration.ou_theta,
                                           cfg.config.exploration.ou_sigma,
                                           cfg.config.exploration.rnd_seed)
    self.max_q = self.step_cnt = 0
    self.agent_weights_id = 0
    self.terminal = False

    if hogwild_update:
        # hogwild mode: updates are queued and applied by a background thread
        self.queue = queue.Queue(10)
        threading.Thread(target=self.execute_tasks).start()
        self.receive_experience()
    else:
        self.queue = None

    if cfg.config.use_filter:
        shape = cfg.config.input.shape
        if shape == [0]:    # scalar inputs are treated as a single-element vector
            shape = [1]
        self.filter = utils.ZFilter(shape)

    if cfg.config.no_ps:
        # no parameter server: the agent initializes its own weights locally
        self.session.op_initialize()
        self.session.op_init_target_weights()
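utils.OUNoise above takes mu, theta, and sigma, matching the Ornstein-Uhlenbeck process commonly used for DDPG exploration. A generic sketch of such a process follows; the class name and defaults are assumptions, not the library's actual implementation.

import numpy as np


class OrnsteinUhlenbeckNoise:
    """Temporally correlated noise: x += theta * (mu - x) + sigma * N(0, 1)."""

    def __init__(self, action_size, mu=0.0, theta=0.15, sigma=0.2, seed=None):
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.rng = np.random.RandomState(seed)
        self.state = np.ones(action_size) * mu

    def reset(self):
        self.state = np.ones_like(self.state) * self.mu

    def noise(self):
        # mean-reverting drift toward mu plus Gaussian diffusion
        dx = self.theta * (self.mu - self.state) + self.sigma * self.rng.randn(*self.state.shape)
        self.state += dx
        return self.state


# usage: perturb a deterministic action during exploration
ou = OrnsteinUhlenbeckNoise(action_size=2, seed=42)
action = np.tanh(np.zeros(2)) + ou.noise()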
def __init__(self, parameter_server, exploit, metrics):
    self.exploit = exploit
    self.metrics = metrics
    self.ps = parameter_server

    model = dppo_model.Model(assemble_model=True)
    self.session = session.Session(policy=model.policy, value_func=model.value_func)

    self.episode = None
    self.steps = 0
    self.observation = observation.Observation(dppo_config.config.input.history)

    if dppo_config.config.use_lstm:
        self.initial_lstm_state = self.lstm_state = self.lstm_zero_state = model.lstm_zero_state
        self.mini_batch_lstm_state = None

    self.terminal = False
    self.last_state = None
    self.last_action = None
    self.last_prob = None
    self.last_terminal = None
    self.final_state = None
    self.final_value = None
    self.policy_step = None
    self.value_step = None

    self.mini_batch_size = dppo_config.config.batch_size
    if dppo_config.config.mini_batch is not None:
        self.mini_batch_size = dppo_config.config.mini_batch

    if dppo_config.config.output.continuous:
        self.prob_type = DiagGauss(dppo_config.config.output.action_size)
    else:
        self.prob_type = Categorical(dppo_config.config.output.action_size)

    if dppo_config.config.use_filter:
        self.filter = ZFilter(dppo_config.config.input.shape)
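Every constructor above builds an observation.Observation from config.input.history, i.e. a rolling window of the last N inputs fed to the network. A minimal frame-stacking sketch follows, assuming the object simply keeps the most recent history states and stacks them; the real class may differ, and ObservationHistory is a hypothetical name.

from collections import deque

import numpy as np


class ObservationHistory:
    """Keeps the last `history` inputs and exposes them as one stacked array."""

    def __init__(self, history):
        self.frames = deque(maxlen=history)

    def add_state(self, state):
        if state is None:                            # e.g. episode reset: clear the window
            self.frames.clear()
            return
        state = np.asarray(state)
        while len(self.frames) < self.frames.maxlen: # pad a fresh window with the first frame
            self.frames.append(state)
        self.frames.append(state)

    def get(self):
        return np.stack(self.frames, axis=-1)        # shape: input_shape + (history,)


# usage: stack the last 4 observations for the network input
obs = ObservationHistory(history=4)
obs.add_state(np.zeros((84, 84)))
obs.add_state(np.ones((84, 84)))
stacked = obs.get()                                  # shape (84, 84, 4)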