def build_actor_nets(input_size: int, num_actions: int,
                     config: Namespace) -> Tuple[NetworkBase, NetworkBase, int]:
    """Build two actor networks that are connected in series."""
    hidden_sizes = config.hidden_sizes
    assert len(hidden_sizes) >= 4 and len(hidden_sizes) % 2 == 0, \
        'the number of hidden sizes should be at least 4 and even (split between actor_a and actor_b)'

    hidden_a = hidden_sizes[:len(hidden_sizes) // 2 - 1]
    thought_size = hidden_sizes[len(hidden_sizes) // 2 - 1]
    hidden_b = hidden_sizes[len(hidden_sizes) // 2 + 1:]

    network = locate(config.network)

    actor_a = network(
        input_size=input_size,
        output_size=thought_size,
        hidden_sizes=hidden_a,
        output_activation=config.actor_a_output_act,  # activation on the thought vector h_t
        output_rescale=None).to(my_device())

    actor_b = network(
        input_size=thought_size * 2,  # assuming the integrated thought is stacked here
        output_size=num_actions,
        hidden_sizes=hidden_b,
        output_activation=config.output_activation,
        output_rescale=config.output_rescale).to(my_device())

    return actor_a, actor_b, thought_size
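# A minimal, self-contained sketch of how the two actors compose in a forward pass,
# using plain torch.nn modules in place of the project's NetworkBase implementations.
# The sizes and module choices below are illustrative assumptions, not taken from the codebase.
import torch
import torch.nn as nn

input_size, thought_size, num_actions = 8, 16, 4

sketch_actor_a = nn.Sequential(nn.Linear(input_size, 32), nn.ReLU(),
                               nn.Linear(32, thought_size), nn.Tanh())
sketch_actor_b = nn.Sequential(nn.Linear(thought_size * 2, 32), nn.ReLU(),
                               nn.Linear(32, num_actions), nn.Tanh())

state = torch.randn(1, input_size)
thought = sketch_actor_a(state)              # h_t, the thought vector produced by actor_a
integrated = torch.zeros_like(thought)       # would come from the communication channel
action = sketch_actor_b(torch.cat([thought, integrated], dim=-1))
print(action.shape)                          # torch.Size([1, 4])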
def _update_network(self, batch: List[Transition]):
    assert isinstance(batch, List)
    assert isinstance(batch[0], Transition)
    assert isinstance(batch[0].action, torch.Tensor)

    # [batch_size, state_size]
    states = torch.stack([tr.state for tr in batch]).squeeze(1).squeeze(1)  # remove 1s [batch_size, 1, 1, data_s]
    new_states = torch.stack([tr.new_state for tr in batch]).squeeze(1).squeeze(1)

    # [batch_size, action_size]
    actions = torch.stack([tr.action for tr in batch]).squeeze(1).squeeze(1)

    # [batch_size, 1]
    rewards = torch.tensor([tr.reward for tr in batch],
                           dtype=torch.float,
                           device=my_device()).unsqueeze(-1)

    self.reset()

    # First, compute the actor loss
    actions_actor = self.actor.forward(states)
    action_vals = self.critic_t.forward(
        torch.cat((states, actions_actor), dim=1))  # TODO critic_t here??

    # actor loss: the higher the action_value, the lower its loss value is => suppress bad actions
    all_actor_losses = -action_vals.mean()

    # Second, compute the critic loss
    with torch.no_grad():
        next_actions = self.actor_t.forward(new_states)
        next_action_vals = self.critic_t.forward(
            torch.cat((new_states, next_actions), dim=1))

        # Bellman eq. here
        target_vals = self.gamma * next_action_vals + rewards  # TODO why is this correct?

    # compute what the critic was actually saying
    remembered_action_vals = self.critic.forward(
        torch.cat((states, actions), dim=1))
    all_critic_losses = self.criterion(remembered_action_vals, target_vals).mean()

    # actor training here
    self.actor.zero_grad()
    all_actor_losses.backward()
    self.actor_optimizer.step()

    # critic training
    self.critic.zero_grad()
    all_critic_losses.backward()
    self.critic_optimizer.step()

    self.num_learns += 1
    self.last_actor_loss = all_actor_losses.item()
    self.last_critic_loss = all_critic_losses.item()
    self.track_targets(self.tau)
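# A minimal sketch of the soft (Polyak) target update that track_targets(tau) is
# assumed to perform here: target <- tau * online + (1 - tau) * target. The helper
# below is an illustration under that assumption, not the project's implementation.
import torch

def soft_update(online: torch.nn.Module, target: torch.nn.Module, tau: float):
    # blend the online parameters into the target parameters in place
    with torch.no_grad():
        for p, p_t in zip(online.parameters(), target.parameters()):
            p_t.mul_(1.0 - tau).add_(tau * p)

online_net = torch.nn.Linear(4, 2)
target_net = torch.nn.Linear(4, 2)
soft_update(online_net, target_net, tau=1.0)    # tau=1.0 copies the online weights exactly
soft_update(online_net, target_net, tau=0.005)  # a small tau lets the target track slowly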
def _make_state(position: float, one_hot_obs, size: int) -> torch.Tensor:
    if one_hot_obs:
        state = SimpleEnvBase.to_one_hot(int(np.floor(position)), size)
    else:
        state = np.array(position).reshape(1)
    state = torch.tensor(state,
                         dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(my_device())
    return state
def remember(self, new_observation: np.ndarray, reward: float, done: bool):
    """No change"""
    obs = torch.tensor(new_observation,
                       dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(my_device())
    self.buffer.remember(self.last_observation, self.last_action, obs, reward, done)
    self.append_reward(ReplayBuffer.sanitize_reward(reward))
def build_actor(input_size: int, num_actions: int, config: Namespace) -> NetworkBase:
    network = locate(config.network)
    actor = network(input_size=input_size,
                    output_size=num_actions,
                    hidden_sizes=config.hidden_sizes,
                    output_activation=config.output_activation,
                    output_rescale=config.output_rescale).to(my_device())
    return actor
def build_critic(input_size: int, network: str,
                 hidden_sizes: Union[int, List[int]]) -> NetworkBase:
    network = locate(network)
    critic = network(input_size=input_size,
                     output_size=1,
                     hidden_sizes=hidden_sizes,
                     output_activation=None,
                     output_rescale=None,
                     softmaxed_parts=None).to(my_device())
    return critic
def pick_action(self, observation: np.ndarray):
    obs = torch.tensor(observation.flatten(),
                       dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(my_device())
    with torch.no_grad():
        actions = self.actor.forward(obs)
    action = self.exploration.pick_action(actions)
    self.last_observation = obs
    self.last_action = action
    return action.to('cpu').numpy()
def _batch_to_tensors(self, batch: List[ATOCMultiagentTransition]) \
        -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                 torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert the sampled batch to the tensor format expected for learning,
    currently [batch_size, num_agents, data_size].

    Returns: states, agent_ids, actions, new_states, new_agent_ids, rewards, comm_matrix_batch
    """
    assert isinstance(batch, List)
    assert isinstance(batch[0], ATOCMultiagentTransition)
    assert isinstance(batch[0].action, List)
    assert isinstance(batch[0].state[0], torch.Tensor)

    # Create a list of all states, then stack and reshape accordingly
    """States, new states, actions, and rewards"""
    states = []
    new_states = []
    actions = []
    rewards = []

    for transition in batch:
        states.extend(transition.state)
        new_states.extend(transition.new_state)
        actions.extend(transition.action)
        rewards.extend(transition.reward)

    # Stack all the data we need
    complete_observations = torch.stack(states).view(
        self.batch_size, self.num_agents, self.input_size)
    new_complete_observations = torch.stack(new_states).view(
        self.batch_size, self.num_agents, self.input_size)
    actions = torch.stack(actions).view(self.batch_size, self.num_agents,
                                        self.output_size)
    rewards = torch.tensor(rewards, dtype=torch.float32,
                           device=my_device()).view(self.batch_size,
                                                    self.num_agents, 1)

    states, agent_ids = self._split_observation_tensor(complete_observations)
    new_states, new_agent_ids = self._split_observation_tensor(new_complete_observations)

    comm_matrix_batch = torch.stack(
        [transition.comm_matrix for transition in batch])

    return states, agent_ids, actions, new_states, new_agent_ids, rewards, comm_matrix_batch
def build_networks(self, config: Namespace, _run: int, observation_space) \
        -> Tuple[NetworkBase, NetworkBase, int, NetworkBase, SimpleFF]:
    # both parts of the actor
    actor_a, actor_b, thought_size = self.build_actor_nets(
        input_size=self.input_size,
        num_actions=self.output_size,
        config=config)

    # standard critic network
    critic = DDPGPolicy.build_critic(
        input_size=self.input_size + self.output_size,
        network=config.network,
        hidden_sizes=config.atoc_critic_sizes)

    # binary classifier for deciding the communication
    classifier = SimpleFF(input_size=thought_size,
                          output_size=1,
                          hidden_sizes=config.classifier_hidden_sizes,
                          output_activation='sigmoid').to(my_device())

    return actor_a, actor_b, thought_size, critic, classifier
def _env_observation_to_tensor(
        self, env_observation: List[np.ndarray]) -> torch.Tensor:
    """Get the observation from the environment and convert it to a torch.Tensor.

    Returns: tensor of size [batch_size=1, num_agents, input_size]
    """
    assert isinstance(env_observation, List), 'observation should be a list of arrays'
    assert len(env_observation) == self.num_agents, \
        'the length of the observation list is incompatible with num_agents!'
    for obs in env_observation:
        assert isinstance(obs, np.ndarray), 'list of arrays expected'
        assert obs.size == env_observation[0].size, 'inconsistent observation sizes'

    obs_array = np.concatenate(env_observation)
    result = torch.tensor(obs_array,
                          dtype=torch.float32,
                          device=my_device()).view(1, self.num_agents, -1)
    return result
def _update_network(self, batch: List[Transition]):
    assert isinstance(batch, List)
    assert isinstance(batch[0], Transition)

    tr = batch[0]
    assert isinstance(tr.state, torch.Tensor)
    assert isinstance(tr.new_state, torch.Tensor)
    assert isinstance(tr.action, torch.Tensor)
    assert len(tr.state.shape) == 3
    assert len(tr.new_state.shape) == 3
    assert len(tr.action.shape) == 3
    assert tr.action.shape[2] == self.num_actions

    all_actor_losses = torch.zeros((1, 1, 1), dtype=torch.float, device=my_device())
    all_critic_losses = torch.zeros(1, dtype=torch.float, device=my_device())

    for tr in batch:  # TODO the done flag not used
        self.reset()

        # DDPG learning: you made a state-action-next_state transition
        #
        # q1: how good was the action?
        #   - use the critic
        #
        # q2: how good was the critic estimate?
        #   - use the actor to make next_action from next_state (we cannot get the argmax_a Q(s, a))
        #   - use the critic to evaluate that Q(next_state, next_action)
        #   - use the Bellman eq. to compute the loss based on the TD between Q(state, action) and Q(next_s, next_a)

        # q1
        action = self.actor.forward(tr.state)
        action_val = self.critic_t.forward(torch.cat((tr.state, action), dim=2))

        # suppress the bad actions: higher action_val => lower loss
        # note: we differentiate through the critic and then the actor;
        # the actor_optimizer has just the actor params, so no update of the critic here
        all_actor_losses += -action_val

        # q2:
        with torch.no_grad():
            next_action = self.actor_t.forward(tr.new_state)
            next_value = self.critic_t.forward(
                torch.cat((tr.new_state, next_action), dim=2))

            # note: compared to the previous case, the actor and critic are used one step in the future
            # to compute the targets for the critic, therefore no gradients in this part
            action_val_target = tr.reward + self.gamma * next_value  # Bellman equation

        remembered_action_val = self.critic.forward(
            torch.cat((tr.state, tr.action), dim=2))

        # compute the output
        all_critic_losses += self.criterion(remembered_action_val, action_val_target)

    all_actor_losses = all_actor_losses / self.batch_size

    # take the optimization steps
    self.actor_optimizer.zero_grad()
    all_actor_losses.backward()
    self.actor_optimizer.step()

    self.critic_optimizer.zero_grad()
    all_critic_losses.backward()
    self.critic_optimizer.step()

    # post-training
    self.num_learns += 1
    self.last_actor_loss = all_actor_losses.item()
    self.last_critic_loss = all_critic_losses.item()
    self.track_targets(self.tau)
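# A worked numeric sketch of the critic target used above (Bellman backup with
# target networks): y = r + gamma * Q_target(s', pi_target(s')). The numbers are
# made up purely for illustration.
import torch

gamma = 0.99
reward = torch.tensor([[1.0]])
next_q = torch.tensor([[2.5]])       # stands in for critic_t(s', actor_t(s'))
target = reward + gamma * next_q     # y = 1.0 + 0.99 * 2.5 = 3.475

current_q = torch.tensor([[3.0]])    # stands in for critic(s, a)
td_error = target - current_q        # 0.475; SmoothL1Loss penalizes this gap
print(target.item(), td_error.item())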
def _update_network(self, batch: List[ATOCMultiagentTransition]):
    """Convert the batch to the expected tensor format"""
    states, agent_ids, actions, new_states, new_agent_ids, rewards, comm_matrix = \
        self._batch_to_tensors(batch)

    self.reset()  # no batch size here

    """Evaluate the actor"""
    thoughts = self.actor_a.forward(states)
    if self.disable_communication:
        integrated_thoughts = torch.zeros_like(thoughts)
    else:
        integrated_thoughts = self.communication.communicate_batched(
            thoughts, agent_ids, comm_matrix)
    complete_thoughts = torch.cat([thoughts, integrated_thoughts], dim=-1)
    actions_actor = self.actor_b.forward(complete_thoughts)

    q_values = self.target_critic.forward(
        torch.cat([states, actions_actor], dim=-1))
    actor_losses = -q_values.mean() / self.num_agents

    """Evaluate the critic"""
    critic_orig_output = self.critic.forward(
        torch.cat([states, actions], dim=-1))

    with torch.no_grad():
        # make action from the s'
        new_thoughts = self.target_actor_a.forward(new_states)
        if self.disable_communication:
            new_integrated_thoughts = torch.zeros_like(new_thoughts)
        else:
            new_integrated_thoughts = self.communication.communicate_batched(
                new_thoughts, new_agent_ids, comm_matrix)
        new_complete_thoughts = torch.cat(
            [new_thoughts, new_integrated_thoughts], dim=-1)
        new_actions = self.target_actor_b.forward(new_complete_thoughts)

        # compute the Q(s', a')
        new_q_values = self.target_critic.forward(
            torch.cat([new_states, new_actions], dim=-1))
        critic_targets = rewards + self.gamma * new_q_values

    critic_losses = self.criterion(critic_orig_output,
                                   critic_targets).mean() / self.num_agents

    """Update the actors and the comm. channel if used"""
    # actor training here
    self.actor_a.zero_grad()
    if not self.disable_communication:
        self.communication_channel.zero_grad()
    self.actor_b.zero_grad()

    actor_losses.backward()

    self.actor_a_optim.step()
    if not self.disable_communication:
        self.communication_channel_optim.step()
    self.actor_b_optim.step()

    """Update the critic"""
    self.critic.zero_grad()
    critic_losses.backward()
    self.critic_optim.step()

    """Update the classifier"""
    if not self.disable_communication and not self.force_communication:
        thoughts, deltas = self.classifier_buffer.sample_normalized_batch(
            self.classifier_batch_size)
        outputs = self.classifier.forward(thoughts)
        classifier_loss = self.classifier_criterion(outputs, deltas)

        self.classifier.zero_grad()
        classifier_loss.backward()
        self.classifier_optim.step()
    else:
        classifier_loss = torch.zeros(1, device=my_device())

    """Post-training"""
    self.num_learns += 1
    self.last_actor_loss = actor_losses.item()
    self.last_critic_loss = critic_losses.item()
    self.last_classifier_loss = classifier_loss.item()
    self.track_targets(self.tau)
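# A hedged sketch of how a sigmoid communication gate like the classifier above
# could be used at decision time: threshold its output on the thought vector to
# decide whether an agent initiates communication. The network shape, the 0.5
# threshold, and the decision rule are assumptions for illustration, not taken
# from this codebase.
import torch
import torch.nn as nn

thought_size = 16
gate = nn.Sequential(nn.Linear(thought_size, 32), nn.ReLU(),
                     nn.Linear(32, 1), nn.Sigmoid())

thought = torch.randn(1, thought_size)
with torch.no_grad():
    p_communicate = gate(thought)            # in (0, 1), trained with BCELoss on the deltas
initiate = bool(p_communicate.item() > 0.5)  # assumed threshold of 0.5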
def _numpy_to_tensors(self, data: List[np.ndarray]) -> List[torch.Tensor]:
    return [
        torch.tensor(d.flatten(),
                     dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(my_device())
        for d in data
    ]
def __init__(self,
             observation_space: Union[List, gym.Space],
             action_space: Union[gym.Space, List],
             config: Namespace,
             _run=None,
             run_id: int = 0):
    super().__init__()

    """Params"""
    self._run = _run
    self.run_id = run_id
    self.gamma = config.gamma

    assert config.tau > 0, 'target networks need to be used'
    self.tau = config.tau
    assert config.batch_size > 0
    self.batch_size = config.batch_size
    assert config.num_learning_iterations > 0
    self.num_learning_iterations = config.num_learning_iterations

    self.comm_decision_period = config.comm_decision_period
    self.comm_bandwidth = config.comm_bandwidth
    self.disable_communication = config.disable_communication

    """IO sizes"""
    self.input_size, self.output_size = self.read_io_size(
        observation_space, action_space)
    self.num_agents = len(observation_space)
    self.num_perceived_agents = config.num_perceived_agents
    assert self.num_perceived_agents <= self.num_agents

    self.last_comm_matrix = torch.zeros(self.num_agents,
                                        self.num_agents,
                                        dtype=torch.bool,
                                        device=my_device())

    """Build networks"""
    self.actor_a, self.actor_b, self.thought_size, self.critic, self.classifier = \
        self.build_networks(config=config, _run=_run, observation_space=observation_space)
    self.target_actor_a, self.target_actor_b, _, self.target_critic, self.target_classifier = \
        self.build_networks(config=config, _run=_run, observation_space=observation_space)
    # TODO use the target classifier

    """Build communication channel networks"""
    atoc_comm_class = locate(config.atoc_comm)
    self.communication = atoc_comm_class(self, config)
    self.communication_channel = self.communication.build_network()
    self.target_communication_channel = self.communication.build_network()  # TODO use the target comm channel
    self.force_communication = config.force_communication

    self.track_targets(tau=1.0)

    """Build optimizers"""
    self.actor_a_optim = torch.optim.Adam(self.actor_a.parameters(), lr=config.lr)
    self.actor_b_optim = torch.optim.Adam(self.actor_b.parameters(), lr=config.lr)
    self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=config.critic_lr)
    self.classifier_optim = torch.optim.Adam(self.classifier.parameters(),
                                             lr=config.classifier_lr)
    self.communication_channel_optim = torch.optim.Adam(
        self.communication_channel.parameters(), lr=config.lr)

    """Build the buffers"""
    self.buffer = ReplayBufferMultiagentATOC(
        buffer_size=config.buffer_size,
        input_sizes=self.num_agents * [self.input_size],
        action_sizes=self.num_agents * [self.output_size],
        _run=_run)

    self.classifier_batch_size = config.classifier_batch_size
    self.classifier_buffer_size = config.classifier_buffer_size
    self.classifier_lr = config.classifier_lr
    self.classifier_buffer = ATOCClassifierBuffer(
        buffer_size=self.classifier_buffer_size,
        thought_size=self.thought_size)

    """Criterion"""
    self.criterion = torch.nn.SmoothL1Loss()
    self.classifier_criterion = torch.nn.BCELoss()

    """Agents"""
    self.agents = [
        ATOCAgent(num_actions=self.output_size, config=config, agent_id=aid)
        for aid in range(self.num_agents)
    ]

    self.num_learns = 0
    self.last_critic_loss = 0
    self.last_actor_loss = 0
    self.last_classifier_loss = 0
def compose_batch_tensors(self, batch: List[List[Transition]]) -> \
        Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor],
              List[torch.Tensor], List[torch.Tensor]]:
    """Convert a list of sequences to a batch of sequences, padding everything to the same max_seq_length.

    Args:
        batch: result of the sample method (list of sequences)

    Returns: [states, actions, next_states, rewards, not_dones];
        - each item is a list of Tensors, one Tensor per position in the sequence (across all sequences),
        - the tensor dimensions are [batch_size, 1, data_size]
    """
    assert isinstance(batch, List)
    assert isinstance(batch[0], List)
    assert isinstance(batch[0][0], Transition)
    assert isinstance(batch[0][0].action, torch.Tensor)  # TODO support also integers here
    assert len(batch[0][0].state.shape) == 3, \
        'expects the states to be tensors of shape [1, 1, state_size]'
    assert len(batch[0][0].new_state.shape) == 3, \
        'expects the new states to be tensors of shape [1, 1, state_size]'
    assert len(batch[0][0].action.shape) == 3, \
        'expects the actions to be tensors of shape [1, 1, num_actions]'

    all_states = []
    all_next_states = []
    all_actions = []
    all_rewards = []
    all_not_dones = []

    max_seq_len = ReplayBufferEpisodicSubsequence._find_max_len(batch)

    # go through each step of the episode(s) and stack tensors from the same step into the batch_size dim
    for sequence_pos in range(max_seq_len):
        # [batch_size, 1, data_size]
        st = torch.stack([
            self._get_state(episode, sequence_pos) for episode in batch
        ]).squeeze(1)  # remove old batch_s=1
        next_st = torch.stack([
            self._get_next_state(episode, sequence_pos) for episode in batch
        ]).squeeze(1)

        # [batch_size, 1]
        act = torch.stack([
            self._get_action(episode, sequence_pos) for episode in batch
        ]).squeeze(1)
        rew = torch.tensor(
            [self._get_reward(episode, sequence_pos) for episode in batch],
            device=my_device()).view(-1, 1, 1)

        # invert the flag for convenience
        not_dones = torch.tensor([
            self._get_not_done(episode, sequence_pos) for episode in batch
        ], dtype=torch.float32, device=my_device()).view(-1, 1, 1)
        self._shift_not_done_flag(not_dones)  # shift one step further

        all_states.append(st)
        all_next_states.append(next_st)
        all_actions.append(act)
        all_rewards.append(rew)
        all_not_dones.append(not_dones)

    # TODO in case the buff_seq_length is None, the reward from the last time step (of the rollout)
    # is not handled correctly by the RDDPG algorithms!

    # artificially mark the end of each sub-sequence as done (not_done = 0)
    all_not_dones[-1] = all_not_dones[-1] * 0

    return all_states, all_actions, all_next_states, all_rewards, all_not_dones
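# A minimal sketch of how a not_done mask is typically consumed downstream: it
# zeroes the bootstrap term at terminal or padded steps, so value estimates do
# not propagate past the end of a (sub-)sequence. The numbers are illustrative.
import torch

gamma = 0.99
rewards = torch.tensor([[1.0], [0.5]])
next_values = torch.tensor([[2.0], [3.0]])
not_dones = torch.tensor([[1.0], [0.0]])  # the second sequence ends here

targets = rewards + gamma * next_values * not_dones
print(targets)  # tensor([[2.9800], [0.5000]])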