def _unpack_observation(self, obs_batch):
    """Unpacks the observation, action mask, and state (if present)
    from agent grouping.

    Returns:
        obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
        mask (np.ndarray): action mask, if any
        state (np.ndarray or None): state tensor of shape [B, state_size]
            or None if it is not in the batch
    """
    unpacked = _unpack_obs(
        np.array(obs_batch, dtype=np.float32),
        self.observation_space.original_space,
        tensorlib=np,
    )

    if isinstance(unpacked[0], dict):
        assert "obs" in unpacked[0]
        unpacked_obs = [
            np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
        ]
    else:
        unpacked_obs = unpacked

    obs = np.concatenate(unpacked_obs, axis=1).reshape(
        [len(obs_batch), self.n_agents, self.obs_size])

    if self.has_action_mask:
        action_mask = np.concatenate(
            [o["action_mask"] for o in unpacked], axis=1).reshape(
                [len(obs_batch), self.n_agents, self.n_actions])
    else:
        action_mask = np.ones(
            [len(obs_batch), self.n_agents, self.n_actions],
            dtype=np.float32)

    if self.has_env_global_state:
        state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
    else:
        state = None

    return obs, action_mask, state
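# A minimal usage sketch for _unpack_observation (hypothetical setup: a
# QMix-style grouped policy with n_agents=2, obs_size=4, n_actions=3, whose
# observation_space.original_space is a Tuple of per-agent Dict spaces; the
# names below are illustrative only, not part of this module):
#
#   obs_batch = [grouped_obs]  # one grouped observation per batch row
#   obs, action_mask, state = policy._unpack_observation(obs_batch)
#   assert obs.shape == (len(obs_batch), policy.n_agents, policy.obs_size)
#   assert action_mask.shape == (
#       len(obs_batch), policy.n_agents, policy.n_actions)
#   # state is None unless ENV_STATE is part of the observation space.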
def _unpack_general(obs_space, obs_keys, num_actions, obs, device=None):
    """Unpacks the observation, action mask, state, and signal (if
    present) from `obs`."""
    if isinstance(obs, dict):
        obs = obs["obs"]
    if not isinstance(obs, torch.Tensor):
        obs = torch.as_tensor(obs, dtype=torch.float, device=device)

    unpacked = _unpack_obs(obs, obs_space, tensorlib=np)

    if isinstance(unpacked, list):
        # Grouped (multi-agent) case: one dict per agent.
        n_agents = len(obs_space.spaces)

        # Observation: concatenate every key except the mask and signal.
        obs = torch.stack([
            torch.cat([
                u[k].reshape(len(obs), -1) for k in obs_keys
                if k != "signal" and k != "action_mask"
            ], dim=-1) for u in unpacked
        ], dim=1).reshape(len(obs), n_agents, -1)

        # Action mask (defaults to all-ones, i.e. every action valid).
        default_mask = torch.as_tensor(
            np.ones(shape=(obs.size(0), num_actions)),
            dtype=torch.float,
            device=obs.device)
        mask = torch.stack([
            u.get("action_mask", default_mask).reshape(len(obs), -1)
            for u in unpacked
        ], dim=1)

        # State
        if unpacked[0].get("state", None) is not None:
            state = torch.stack([
                u.get("state").reshape(len(obs), -1) for u in unpacked
            ], dim=1).reshape(len(obs), n_agents, -1)
        else:
            state = None

        # Signals
        if unpacked[0].get("signal", None) is not None:
            signal = torch.stack([
                u.get("signal").reshape(len(obs), -1) for u in unpacked
            ], dim=1).reshape(len(obs), n_agents, -1)
        else:
            signal = None
    else:
        # Single-agent case: `unpacked` is a single dict.
        obs = torch.cat([
            unpacked[k].reshape(len(obs), -1) for k in obs_keys
            if k != "signal" and k != "action_mask"
        ], dim=-1)

        # Action mask
        default_mask = torch.as_tensor(
            np.ones(shape=(obs.size(0), num_actions)),
            dtype=torch.float,
            device=obs.device)
        mask = unpacked.get("action_mask", default_mask)

        # State
        state = unpacked.get("state", None)

        # Signal
        signal = unpacked.get("signal", None)

    return obs, mask, state, signal
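# A minimal usage sketch for _unpack_general, assuming a grouped Tuple space
# of two agents, each a Dict with "obs" and "action_mask" keys (sizes are
# hypothetical, and the exact space plumbing may vary across RLlib versions):
#
#   from gym.spaces import Box, Dict, Tuple
#
#   agent_space = Dict({
#       "action_mask": Box(0.0, 1.0, shape=(3,)),
#       "obs": Box(-1.0, 1.0, shape=(4,)),
#   })
#   space = Tuple([agent_space, agent_space])
#   flat = np.random.rand(8, 14).astype(np.float32)  # B=8, 2 * (3 + 4)
#   obs, mask, state, signal = _unpack_general(
#       space, ["obs"], num_actions=3, obs=flat)
#   # obs: [8, 2, 4], mask: [8, 2, 3]; state and signal are None here.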
def learn_on_batch(self, train_batch):
    # Set models to train mode.
    if self.model:
        self.model.train()
    if self.dynamics:
        self.dynamics.train()

    stats = defaultdict(int)

    # Update the skill-dynamics model.
    if self.use_dynamics:
        c = 0
        for ep in range(self.dynamics_epochs):
            for mb in minibatches(train_batch, self.minibatch_size):
                c += 1
                mb["is_training"] = True
                minibatch = self._lazy_tensor_dict(mb)
                obs = _unpack_obs(
                    minibatch['obs'],
                    self.model.options['orig_obs_space'],
                    tensorlib=torch)
                next_obs = _unpack_obs(
                    minibatch['new_obs'],
                    self.model.options['orig_obs_space'],
                    tensorlib=torch)
                dynamics_obs = obs['dynamics_obs']
                # The dynamics model is trained on state deltas, not on
                # the raw next state.
                next_dynamics_obs = (
                    next_obs['dynamics_obs'] - obs['dynamics_obs'])
                z = obs['z']

                log_prob = self.dynamics.get_log_prob(
                    dynamics_obs, z, next_dynamics_obs, training=True)
                dynamics_loss = -torch.mean(log_prob)
                orth_loss = self.dynamics.orthogonal_regularization()
                l2_loss = self.dynamics.l2_regularization()
                if self.config['dynamics_orth_reg']:
                    dynamics_loss += orth_loss
                if (self.config['dynamics_l2_reg']
                        and not self.config['dynamics_spectral_norm']):
                    dynamics_loss += l2_loss

                self.dynamics_opt.zero_grad()
                dynamics_loss.backward()
                if self.config['grad_clip']:
                    grad_norm = nn.utils.clip_grad_norm_(
                        self.dynamics.parameters(),
                        self.config['grad_clip'])
                self.dynamics_opt.step()

                stats['dynamics_loss'] += dynamics_loss.item()
                stats['orth_loss'] += orth_loss.item()
                stats['l2_loss'] += l2_loss.item()

        stats['dynamics_loss'] /= c
        stats['orth_loss'] /= c
        stats['l2_loss'] /= c

        self.dynamics.eval()

        # Compute the intrinsic (DADS) reward and replace the reward
        # column in train_batch with it.
        with torch.no_grad():
            batch = self._lazy_tensor_dict(train_batch)
            obs = _unpack_obs(
                batch['obs'],
                self.model.options['orig_obs_space'],
                tensorlib=torch)
            next_obs = _unpack_obs(
                batch['new_obs'],
                self.model.options['orig_obs_space'],
                tensorlib=torch)
            z = obs['z']
            dynamics_obs = obs['dynamics_obs']
            next_dynamics_obs = (
                next_obs['dynamics_obs'] - obs['dynamics_obs'])
            dads_reward, info = self.dynamics.compute_reward(
                dynamics_obs, z, next_dynamics_obs)
            dads_reward = (
                self.config['dads_reward_scale'] * dads_reward.numpy())

        train_batch['rewards'] = dads_reward
        stats['avg_dads_reward'] = dads_reward.mean()
        stats['num_skills_higher_prob'] = info['num_higher_prob']
    # Recompute advantages (GAE) from the relabeled rewards, episode by
    # episode.
    trajs = train_batch.split_by_episode()
    processed_trajs = []
    for traj in trajs:
        processed_trajs.append(compute_gae_for_sample_batch(self, traj))
    batch = SampleBatch.concat_samples(processed_trajs)

    # Update the agent with PPO on minibatches of the relabeled batch.
    c = 0
    for ep in range(self.ppo_epochs):
        for mb in minibatches(batch, self.minibatch_size):
            c += 1
            mb["is_training"] = True
            mb['advantages'] = standardize(mb['advantages'])
            minibatch = self._lazy_tensor_dict(mb)

            # Compute the surrogate loss (the learning rate is already
            # applied inside ppo_surrogate_loss).
            loss_out = ppo_surrogate_loss(
                self, self.model, self.dist_class, minibatch)

            # Compute and apply the gradient.
            self.ppo_opt.zero_grad()
            loss_out.backward()
            if self.config['grad_clip']:
                grad_norm = nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.config['grad_clip'])
            self.ppo_opt.step()

            stats['ppo_loss'] += loss_out.item()

    stats['ppo_loss'] /= c

    # Add KL and loss stats for logging.
    stats.update(kl_and_loss_stats(self, train_batch))

    return {LEARNER_STATS_KEY: stats}
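# For reference, a hedged sketch of the `minibatches` helper used above (the
# actual local utility may differ; SampleBatch.slice and SampleBatch.count
# are standard RLlib APIs):
#
#   def minibatches(batch, size):
#       """Yield contiguous SampleBatch slices of at most `size` rows."""
#       for start in range(0, batch.count, size):
#           yield batch.slice(start, start + size)
#
# learn_on_batch thus follows the two-phase DADS update: fit the skill
# dynamics on state deltas, relabel rewards with the intrinsic DADS reward,
# recompute GAE, then take PPO steps on the relabeled batch.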