# Standard and third-party imports used below. The project-local names `flags`, `Buffer`,
# `PrioritizedBuffer`, `ExperienceBatch` and the `*_Network` classes are assumed to be
# provided by the surrounding package.
import copy
import threading
from collections import deque

import numpy as np
import tensorflow as tf
from sklearn.cluster import KMeans
from sklearn.random_projection import SparseRandomProjection


class BasicManager(object):

    def __init__(self, session, device, id, action_shape, state_shape, concat_size=0, global_network=None, training=True):
        self.training = training
        self.session = session
        self.id = id
        self.device = device
        self.state_shape = state_shape
        self.set_model_size()
        if self.training:
            self.global_network = global_network
            # Gradient optimizer and clip range
            if not self.is_global_network():
                self.clip = self.global_network.clip
            else:
                self.initialize_gradient_optimizer()
            # Build agents
            self.model_list = []
            self.build_agents(state_shape=state_shape, action_shape=action_shape, concat_size=concat_size)
            # Build experience buffer
            if flags.replay_ratio > 0:
                if flags.prioritized_replay:
                    self.experience_buffer = PrioritizedBuffer(size=flags.replay_buffer_size)
                    # self.beta_schedule = LinearSchedule(flags.max_time_step, initial_p=0.4, final_p=1.0)
                else:
                    self.experience_buffer = Buffer(size=flags.replay_buffer_size)
            if flags.predict_reward:
                self.reward_prediction_buffer = Buffer(size=flags.reward_prediction_buffer_size)
            # Bind optimizer to global
            if not self.is_global_network():
                self.bind_to_global(self.global_network)
            # Count based exploration
            if flags.use_count_based_exploration_reward:
                self.projection = None
                self.projection_dataset = []
            if flags.print_loss:
                self._loss_list = [{} for _ in range(self.model_size)]
        else:
            self.global_network = None
            self.model_list = global_network.model_list
        # Statistics
        self._model_usage_list = deque()

    def is_global_network(self):
        return self.global_network is None

    def set_model_size(self):
        self.model_size = 1
        self.agents_set = set([0])

    def build_agents(self, state_shape, action_shape, concat_size):
        agent = eval('{}_Network'.format(flags.network))(
            session=self.session,
            id='{0}_{1}'.format(self.id, 0),
            device=self.device,
            state_shape=state_shape,
            action_shape=action_shape,
            concat_size=concat_size,
            clip=self.clip[0],
            predict_reward=flags.predict_reward,
            training=self.training
        )
        self.model_list.append(agent)

    def sync(self):
        # assert not self.is_global_network(), 'you are trying to sync the global network with itself'
        for i in range(self.model_size):
            agent = self.model_list[i]
            sync = self.sync_list[i]
            agent.sync(sync)

    def initialize_gradient_optimizer(self):
        self.global_step = []
        self.learning_rate = []
        self.clip = []
        self.gradient_optimizer = []
        for i in range(self.model_size):
            # global step
            self.global_step.append(tf.Variable(0, trainable=False))
            # learning rate
            self.learning_rate.append(
                eval('tf.train.' + flags.alpha_annealing_function)(learning_rate=flags.alpha, global_step=self.global_step[i], decay_steps=flags.alpha_decay_steps, decay_rate=flags.alpha_decay_rate)
                if flags.alpha_decay else flags.alpha
            )
            # clip
            self.clip.append(
                eval('tf.train.' + flags.clip_annealing_function)(learning_rate=flags.clip, global_step=self.global_step[i], decay_steps=flags.clip_decay_steps, decay_rate=flags.clip_decay_rate)
                if flags.clip_decay else flags.clip
            )
            # gradient optimizer
            self.gradient_optimizer.append(
                eval('tf.train.' + flags.optimizer + 'Optimizer')(learning_rate=self.learning_rate[i], use_locking=True)
            )

    def bind_to_global(self, global_network):
        self.sync_list = []
        for i in range(self.model_size):
            local_agent = self.get_model(i)
            global_agent = global_network.get_model(i)
            local_agent.minimize_local_loss(optimizer=global_network.gradient_optimizer[i], global_step=global_network.global_step[i], global_var_list=global_agent.get_shared_keys())
            self.sync_list.append(local_agent.bind_sync(global_agent))  # for syncing the local network with the global one

    def get_model(self, id):
        return self.model_list[id]

    def get_statistics(self):
        stats = {}
        if self.training:
            # build loss statistics
            if flags.print_loss:
                for i in range(self.model_size):
                    for key, value in self._loss_list[i].items():
                        stats['loss_{}{}_avg'.format(key, i)] = np.average(value)
            # build models usage statistics
            if self.model_size > 1:
                total_usage = 0
                usage_matrix = {}
                for u in self._model_usage_list:
                    if not (u in usage_matrix):
                        usage_matrix[u] = 0
                    usage_matrix[u] += 1
                    total_usage += 1
                for i in range(self.model_size):
                    stats['model_{}'.format(i)] = 0
                for key, value in usage_matrix.items():
                    stats['model_{}'.format(key)] = value/total_usage if total_usage != 0 else 0
        return stats

    def add_to_statistics(self, id):
        self._model_usage_list.append(id)
        if len(self._model_usage_list) > flags.match_count_for_evaluation:
            self._model_usage_list.popleft()  # remove old statistics

    def get_shared_keys(self):
        vars = []
        for agent in self.model_list:
            vars += agent.get_shared_keys()
        return vars

    def reset(self):
        self.step = 0
        self.agent_id = 0
        # Internal states
        self.internal_states = None if flags.share_internal_state else [None]*self.model_size
        if self.training:
            # Count based exploration
            if flags.use_count_based_exploration_reward:
                self.hash_state_table = {}

    def initialize_new_batch(self):
        self.batch = ExperienceBatch(self.model_size)

    def estimate_value(self, agent_id, states, concats=None, internal_state=None):
        return self.get_model(agent_id).predict_value(states=states, concats=concats, internal_state=internal_state)

    def act(self, act_function, state, concat=None):
        agent_id = self.agent_id
        agent = self.get_model(agent_id)
        internal_state = self.internal_states if flags.share_internal_state else self.internal_states[agent_id]
        action_batch, value_batch, policy_batch, new_internal_state = agent.predict_action(states=[state], concats=[concat], internal_state=internal_state)
        if flags.share_internal_state:
            self.internal_states = new_internal_state
        else:
            self.internal_states[agent_id] = new_internal_state
        action, value, policy = action_batch[0], value_batch[0], policy_batch[0]
        new_state, extrinsic_reward, terminal = act_function(action)
        if self.training:
            if flags.clip_reward:
                extrinsic_reward = np.clip(extrinsic_reward, flags.min_reward, flags.max_reward)
        intrinsic_reward = 0
        if self.training:
            if flags.use_count_based_exploration_reward:
                # intrinsic reward
                intrinsic_reward += self.get_count_based_exploration_reward(new_state)
        total_reward = np.array([extrinsic_reward, intrinsic_reward], dtype=np.float32)
        if self.training:
            self.batch.add_action(agent_id=agent_id, state=state, concat=concat, action=action, policy=policy, reward=total_reward, value=value, internal_state=internal_state)
        # update step at the end of the action
        self.step += 1
        # return result
        return new_state, value, action, total_reward, terminal, policy

    def get_count_based_exploration_reward(self, new_state):
        if len(self.projection_dataset) < flags.projection_dataset_size:
            self.projection_dataset.append(new_state.flatten())
        if len(self.projection_dataset) == flags.projection_dataset_size:
            if self.projection is None:
                self.projection = SparseRandomProjection(n_components=flags.exploration_hash_size if flags.exploration_hash_size > 0 else 'auto')  # http://scikit-learn.org/stable/modules/random_projection.html
            self.projection.fit(self.projection_dataset)
            self.projection_dataset = []  # reset
        if self.projection is not None:
            state_projection = self.projection.transform([new_state.flatten()])[0]  # project to smaller dimension
            state_hash = ''.join('1' if x > 0 else '0' for x in state_projection)  # build binary locality-sensitive hash
            if state_hash not in self.hash_state_table:
                self.hash_state_table[state_hash] = 1
            else:
                self.hash_state_table[state_hash] += 1
            exploration_bonus = 2/np.sqrt(self.hash_state_table[state_hash]) - 1  # in [-1,1]
            return flags.positive_exploration_coefficient*exploration_bonus if exploration_bonus > 0 else flags.negative_exploration_coefficient*exploration_bonus
        return 0

    def compute_discounted_cumulative_reward(self, batch):
        last_value = batch.bootstrap['value'] if 'value' in batch.bootstrap else 0.
        batch.compute_discounted_cumulative_reward(agents=self.agents_set, last_value=last_value, gamma=flags.gamma, lambd=flags.lambd)
        return batch

    def train(self, batch):
        # assert self.global_network is not None, 'Cannot train the global network.'
        states = batch.states
        internal_states = batch.internal_states
        concats = batch.concats
        actions = batch.actions
        policies = batch.policies
        values = batch.values
        rewards = batch.rewards
        dcr = batch.discounted_cumulative_rewards
        gae = batch.generalized_advantage_estimators
        batch_error = []
        for i in range(self.model_size):
            batch_size = len(states[i])
            if batch_size > 0:
                model = self.get_model(i)
                # reward prediction
                if model.predict_reward:
                    sampled_batch = self.reward_prediction_buffer.sample()
                    reward_prediction_states, reward_prediction_target = self.get_reward_prediction_tuple(sampled_batch)
                else:
                    reward_prediction_states = None
                    reward_prediction_target = None
                # train
                error, train_info = model.train(
                    states=states[i],
                    concats=concats[i],
                    actions=actions[i],
                    values=values[i],
                    policies=policies[i],
                    rewards=rewards[i],
                    discounted_cumulative_rewards=dcr[i],
                    generalized_advantage_estimators=gae[i],
                    reward_prediction_states=reward_prediction_states,
                    reward_prediction_target=reward_prediction_target,
                    internal_state=internal_states[i][0]
                )
                batch_error.append(error)
                # loss statistics
                if flags.print_loss:
                    for key, value in train_info.items():
                        if key not in self._loss_list[i]:
                            self._loss_list[i][key] = deque()
                        self._loss_list[i][key].append(value)
                        if len(self._loss_list[i][key]) > flags.match_count_for_evaluation:  # remove old statistics
                            self._loss_list[i][key].popleft()
        return batch_error

    def bootstrap(self, state, concat=None):
        agent_id = self.agent_id
        internal_state = self.internal_states if flags.share_internal_state else self.internal_states[agent_id]
        value_batch, _ = self.estimate_value(agent_id=agent_id, states=[state], concats=[concat], internal_state=internal_state)
        bootstrap = self.batch.bootstrap
        bootstrap['internal_state'] = internal_state
        bootstrap['agent_id'] = agent_id
        bootstrap['state'] = state
        bootstrap['concat'] = concat
        bootstrap['value'] = value_batch[0]

    def replay_value(self, batch):  # replay values
        for (agent_id, pos) in batch.step_generator():
            concat, state, internal_state = batch.get_action(['concats', 'states', 'internal_states'], agent_id, pos)
            value_batch, _ = self.estimate_value(agent_id=agent_id, states=[state], concats=[concat], internal_state=internal_state)
            batch.set_action({'values': value_batch[0]}, agent_id, pos)
        if 'value' in batch.bootstrap:
            bootstrap = batch.bootstrap
            agent_id = bootstrap['agent_id']
            value_batch, _ = self.estimate_value(agent_id=agent_id, states=[bootstrap['state']], concats=[bootstrap['concat']], internal_state=bootstrap['internal_state'])
            bootstrap['value'] = value_batch[0]
        return self.compute_discounted_cumulative_reward(batch)

    def add_to_reward_prediction_buffer(self, batch):
        batch_size = batch.get_size(self.agents_set)
        if batch_size < 2:
            return
        batch_extrinsic_reward = batch.get_cumulative_reward(self.agents_set)[0]
        self.reward_prediction_buffer.put(batch=batch, type_id=1 if batch_extrinsic_reward != 0 else 0)  # process batch only after sampling, for better performance

    def get_reward_prediction_tuple(self, batch):
        flat_states = [batch.get_action('states', agent_id, pos) for (agent_id, pos) in batch.step_generator(self.agents_set)]
        flat_rewards = [batch.get_action('rewards', agent_id, pos) for (agent_id, pos) in batch.step_generator(self.agents_set)]
        states_count = len(flat_states)
        length = min(3, states_count-1)
        start_idx = np.random.randint(states_count-length) if states_count > length else 0
        reward_prediction_states = [flat_states[start_idx+i] for i in range(length)]
        reward_prediction_target = np.zeros((1, 3))
        target_reward = flat_rewards[start_idx+length][0]  # use only extrinsic rewards
        if target_reward == 0:
            reward_prediction_target[0][0] = 1.0  # zero
        elif target_reward > 0:
            reward_prediction_target[0][1] = 1.0  # positive
        else:
            reward_prediction_target[0][2] = 1.0  # negative
        return reward_prediction_states, reward_prediction_target

    def add_to_replay_buffer(self, batch, batch_error):
        batch_size = batch.get_size(self.agents_set)
        if batch_size < 1:
            return
        batch_reward = batch.get_cumulative_reward(self.agents_set)
        batch_extrinsic_reward = batch_reward[0]
        batch_intrinsic_reward = batch_reward[1]
        batch_tot_reward = batch_extrinsic_reward + batch_intrinsic_reward
        if batch_tot_reward == 0 and flags.save_only_batches_with_reward:
            return
        if flags.replay_using_default_internal_state:
            batch.reset_internal_states()
        type_id = (1 if batch_intrinsic_reward > 0 else (2 if batch_extrinsic_reward > 0 else 0))
        if flags.prioritized_replay:
            self.experience_buffer.put(batch=batch, priority=batch_tot_reward, type_id=type_id)
        else:
            self.experience_buffer.put(batch=batch, type_id=type_id)

    def replay_experience(self):
        if not self.experience_buffer.has_atleast(flags.replay_start):
            return
        n = np.random.poisson(flags.replay_ratio)
        for _ in range(n):
            old_batch = self.experience_buffer.sample()
            self.train(self.replay_value(old_batch) if flags.replay_value else old_batch)

    def process_batch(self, global_step):
        batch = self.compute_discounted_cumulative_reward(self.batch)
        # reward prediction
        if flags.predict_reward:
            self.add_to_reward_prediction_buffer(batch)  # do it before training, this way there will be at least one batch in the reward_prediction_buffer
            if self.reward_prediction_buffer.is_empty():
                return  # cannot train without reward prediction, wait until the reward_prediction_buffer is not empty
        # train
        batch_error = self.train(batch)
        # experience replay (after training!)
        if flags.replay_ratio > 0 and global_step > flags.replay_step:
            self.replay_experience()
            self.add_to_replay_buffer(batch, batch_error)
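
# --- Illustrative sketch (not part of the original framework) -------------------------------
# BasicManager.get_count_based_exploration_reward implements a SimHash-style count-based
# exploration bonus: flattened states are projected to a low-dimensional space with a sparse
# random projection, the sign pattern of the projection acts as a locality-sensitive binary
# hash, and the visitation count of each hash bucket yields a bonus of 2/sqrt(count) - 1.
# The minimal, self-contained class below shows the same idea in isolation; the name
# `HashedStateCounter` and its parameters are illustrative assumptions, not project code.

class HashedStateCounter(object):

    def __init__(self, hash_size=16):
        self.projection = SparseRandomProjection(n_components=hash_size)
        self.hash_table = {}
        self.fitted = False

    def fit(self, flat_states):
        # flat_states: array-like of shape (n_samples, n_features), e.g. a dataset of flattened observations
        self.projection.fit(np.asarray(flat_states))
        self.fitted = True

    def bonus(self, flat_state):
        if not self.fitted:
            return 0.0
        projected = self.projection.transform([np.asarray(flat_state)])[0]  # project to a smaller dimension
        state_hash = ''.join('1' if x > 0 else '0' for x in projected)  # binary locality-sensitive hash
        self.hash_table[state_hash] = self.hash_table.get(state_hash, 0) + 1
        return 2.0 / np.sqrt(self.hash_table[state_hash]) - 1.0  # in [-1, 1]

# Usage sketch: frequently visited states hash to the same bucket, so their bonus decays
# toward -1, while novel states keep a bonus close to +1.
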
class KMeansPartitioner(BasicManager):

    def set_model_size(self):
        self.model_size = flags.partition_count  # manager output size
        if self.model_size < 2:
            self.model_size = 2
        self.agents_set = set(range(self.model_size))

    def build_agents(self, state_shape, action_shape, concat_size):
        # partitioner
        if self.is_global_network():
            self.buffer = Buffer(size=flags.partitioner_dataset_size)
            self.partitioner = KMeans(n_clusters=self.model_size)
            self.partitioner_trained = False
        # agents
        self.model_list = []
        for i in range(self.model_size):
            agent = eval(flags.network + "_Network")(
                id="{0}_{1}".format(self.id, i),
                device=self.device,
                session=self.session,
                state_shape=state_shape,
                action_shape=action_shape,
                concat_size=concat_size,
                clip=self.clip[i],
                predict_reward=flags.predict_reward,
                training=self.training
            )
            self.model_list.append(agent)
        # bind partition nets to training net
        if self.is_global_network():
            self.bind_to_training_net()
            self.lock = threading.Lock()

    def bind_to_training_net(self):
        self.sync_list = []
        training_net = self.get_model(0)
        for i in range(1, self.model_size):
            partition_net = self.get_model(i)
            self.sync_list.append(partition_net.bind_sync(training_net))  # for syncing the partition nets with the training net

    def sync_with_training_net(self):
        for i in range(1, self.model_size):
            self.model_list[i].sync(self.sync_list[i - 1])

    def get_state_partition(self, state):
        id = self.partitioner.predict([state.flatten()])[0]
        # print(self.id, " ", id)
        self.add_to_statistics(id)
        return id

    def query_partitioner(self, step):
        return self.partitioner_trained and step % flags.partitioner_granularity == 0

    def act(self, act_function, state, concat=None):
        if self.query_partitioner(self.step):
            self.agent_id = self.get_state_partition(state)
        return super().act(act_function, state, concat)

    def populate_partitioner(self, states):
        # assert self.is_global_network(), 'only the global network can populate the partitioner'
        with self.lock:
            if not self.partitioner_trained:
                for i in range(0, len(states), flags.partitioner_granularity):
                    state = states[i]
                    self.buffer.put(batch=state.flatten())
                    if self.buffer.is_full():
                        print("Buffer is full, starting partitioner training")
                        self.partitioner.fit(self.buffer.get_batches())
                        print("Partitioner trained")
                        self.partitioner_trained = True
                        print("Syncing with training net")
                        self.sync_with_training_net()
                        print("Cleaning buffer")
                        self.buffer.clean()

    def bootstrap(self, state, concat=None):
        if self.query_partitioner(self.step):
            self.agent_id = self.get_state_partition(state)
        super().bootstrap(state, concat)
        # populate the partitioner training set
        if not self.partitioner_trained and not self.is_global_network():
            # if the partitioner is not trained yet, all the states are associated to the current agent
            self.global_network.populate_partitioner(states=self.batch.states[self.agent_id])
            self.partitioner_trained = self.global_network.partitioner_trained
            if self.partitioner_trained:
                self.partitioner = copy.deepcopy(self.global_network.partitioner)
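
# --- Illustrative sketch (not part of the original framework) -------------------------------
# KMeansPartitioner routes each state to one of `model_size` sub-agents: flattened states
# sampled every `partitioner_granularity` steps are buffered, a KMeans model is fitted on that
# buffer, and afterwards the predicted cluster id selects which agent in model_list acts on
# (and is trained from) a given state. The helper functions below sketch that routing in
# isolation; their names and arguments are illustrative assumptions, not project code.

def fit_state_partitioner(sampled_flat_states, partition_count):
    # sampled_flat_states: list/array of flattened states, with at least `partition_count` samples
    partitioner = KMeans(n_clusters=partition_count)
    partitioner.fit(np.asarray(sampled_flat_states))
    return partitioner

def get_partition_id(partitioner, state):
    # returns the index of the sub-agent (model_list[id]) responsible for this state
    return int(partitioner.predict([np.asarray(state).flatten()])[0])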