Example #1
    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """

        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np,
        )

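        # `unpacked` is a list with one entry per agent in the group; each entry
        # is either a plain tensor or a dict with "obs", "action_mask", and the
        # global-state keys.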
        if isinstance(unpacked[0], dict):
            assert "obs" in unpacked[0]
            unpacked_obs = [
                np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
            ]
        else:
            unpacked_obs = unpacked

        obs = np.concatenate(unpacked_obs, axis=1).reshape(
            [len(obs_batch), self.n_agents, self.obs_size])

        if self.has_action_mask:
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
        else:
            state = None
        return obs, action_mask, state
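
For orientation, a minimal sketch of the grouped observation space this method assumes and the shapes it returns. The sizes, the gym.spaces construction, and the literal "state" key (standing in for ENV_STATE) are my own illustration, not part of the example:

# Illustrative sketch only; names and sizes are assumptions, see the note above.
import numpy as np
from gym.spaces import Box, Dict, Tuple

n_agents, n_actions, obs_size, state_size = 2, 5, 10, 8
agent_space = Dict({
    "obs": Box(-1.0, 1.0, shape=(obs_size,), dtype=np.float32),
    "action_mask": Box(0.0, 1.0, shape=(n_actions,), dtype=np.float32),
    "state": Box(-1.0, 1.0, shape=(state_size,), dtype=np.float32),  # ENV_STATE slot
})
# What observation_space.original_space looks like for the grouped policy:
grouped_space = Tuple([agent_space] * n_agents)

# For a batch of B flattened observations from this space, _unpack_observation
# returns:
#   obs.shape         == (B, n_agents, obs_size)
#   action_mask.shape == (B, n_agents, n_actions)
#   state.shape       == (B, state_size), or None without a global state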
Example #2
def _unpack_general(obs_space, obs_keys, num_actions, obs, device=None):
    """Unpack the flattened observation `obs` into observation, action mask,
    state, and signal tensors (one slice per agent when `obs_space` is a
    Tuple of per-agent spaces)."""
    if isinstance(obs, dict):
        obs = obs['obs']

    if not isinstance(obs, torch.Tensor):
        obs = torch.as_tensor(obs, dtype=torch.float, device=device)

    unpacked = _unpack_obs(obs, obs_space, tensorlib=np)

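    # For a Tuple space (one entry per agent), RLlib's `_unpack_obs` returns a
    # list of per-agent dicts; for a single Dict space it returns one dict.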
    # Observation
    if isinstance(unpacked, list):
        n_agents = len(obs_space.spaces)
        # obs
        obs = torch.stack(
            [torch.cat([u[k].reshape(len(obs), -1)
                        for k in obs_keys
                        if k not in ("signal", "action_mask")], dim=-1)
             for u in unpacked],
            dim=1).reshape(len(obs), n_agents, -1)
        # mask
        default_mask = torch.ones(
            (obs.size(0), num_actions), dtype=torch.float, device=obs.device)
        mask = torch.stack(
            [u.get("action_mask", default_mask).reshape(len(obs), -1)
             for u in unpacked],
            dim=1)
        # state
        if unpacked[0].get("state", None) is not None:
            state = torch.stack(
                [u.get("state").reshape(len(obs), -1) for u in unpacked],
                dim=1).reshape(len(obs), n_agents, -1)
        else:
            state = None

        # signals
        if unpacked[0].get("signal", None) is not None:
            signal = torch.stack(
                [u.get("signal").reshape(len(obs), -1) for u in unpacked],
                dim=1).reshape(len(obs), n_agents, -1)
        else:
            signal = None

    else:
        obs = torch.cat(
            [unpacked[k].reshape(len(obs), -1)
             for k in obs_keys if k not in ("signal", "action_mask")],
            dim=-1)

        # Action mask
        default_mask = torch.ones(
            (obs.size(0), num_actions), dtype=torch.float, device=obs.device)
        mask = unpacked.get("action_mask", default_mask)

        # State
        state = unpacked.get("state", None)

        # Signal
        signal = unpacked.get("signal", None)

    return obs, mask, state, signal
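
A short usage note. The call and shapes below are illustrative assumptions (the function also relies on torch, numpy as np, and RLlib's _unpack_obs being imported); nothing here is taken verbatim from the example's source:

# Hedged sketch of the call and of the returned shapes.
#
#   obs, mask, state, signal = _unpack_general(
#       obs_space=policy.observation_space.original_space,
#       obs_keys=["obs"],      # Dict keys concatenated into the observation
#       num_actions=5,
#       obs=flat_batch,        # (B, flat_obs_size) flattened observations
#       device="cpu")
#
# Tuple-of-Dict (multi-agent) space:
#   obs.shape  == (B, n_agents, per_agent_obs_size)
#   mask.shape == (B, n_agents, num_actions)
#   state / signal are (B, n_agents, -1) tensors, or None if absent
# Single Dict space:
#   obs.shape  == (B, obs_size), mask.shape == (B, num_actions),
#   state / signal are returned as-is from the unpacked dict (or None).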
Example #3
    def learn_on_batch(self, train_batch):
        # Train the skill-dynamics model, relabel rewards with the intrinsic
        # (DADS) reward, recompute GAE, then run the PPO update.

        # Set Model to train mode.
        if self.model:
            self.model.train()
        if self.dynamics:
            self.dynamics.train()

        stats = defaultdict(int)
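        # Loss terms are accumulated per minibatch and averaged by the counter c.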
        if self.use_dynamics:
            c = 0
            for ep in range(self.dynamics_epochs):
                for mb in minibatches(train_batch, self.minibatch_size):
                    c += 1
                    mb["is_training"] = True
                    minibatch = self._lazy_tensor_dict(mb)

                    obs = _unpack_obs(minibatch['obs'],
                                      self.model.options['orig_obs_space'],
                                      torch)
                    next_obs = _unpack_obs(
                        minibatch['new_obs'],
                        self.model.options['orig_obs_space'], torch)
                    dynamics_obs = obs['dynamics_obs']
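                    # The dynamics target is the observation delta (next minus
                    # current), not the absolute next observation.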
                    next_dynamics_obs = next_obs['dynamics_obs'] - obs[
                        'dynamics_obs']
                    z = obs['z']

                    log_prob = self.dynamics.get_log_prob(dynamics_obs,
                                                          z,
                                                          next_dynamics_obs,
                                                          training=True)
                    dynamics_loss = -torch.mean(log_prob)
                    orth_loss = self.dynamics.orthogonal_regularization()
                    l2_loss = self.dynamics.l2_regularization()
                    if self.config['dynamics_orth_reg']:
                        dynamics_loss += orth_loss
                    if self.config['dynamics_l2_reg'] and not self.config[
                            'dynamics_spectral_norm']:
                        dynamics_loss += l2_loss
                    self.dynamics_opt.zero_grad()
                    dynamics_loss.backward()
                    if self.config['grad_clip']:
                        grad_norm = nn.utils.clip_grad_norm_(
                            self.dynamics.parameters(),
                            self.config['grad_clip'])
                    self.dynamics_opt.step()
                    stats['dynamics_loss'] += dynamics_loss.item()
                    stats['orth_loss'] += orth_loss.item()
                    stats['l2_loss'] += l2_loss.item()
            stats['dynamics_loss'] /= c
            stats['orth_loss'] /= c
            stats['l2_loss'] /= c

            self.dynamics.eval()
            # compute intrinsic reward
            with torch.no_grad():
                batch = self._lazy_tensor_dict(train_batch)
                obs = _unpack_obs(batch['obs'],
                                  self.model.options['orig_obs_space'], torch)
                next_obs = _unpack_obs(batch['new_obs'],
                                       self.model.options['orig_obs_space'],
                                       torch)
                z = obs['z']
                dynamics_obs = obs['dynamics_obs']
                next_dynamics_obs = next_obs['dynamics_obs'] - obs[
                    'dynamics_obs']

                dads_reward, info = self.dynamics.compute_reward(
                    dynamics_obs, z, next_dynamics_obs)
                dads_reward = self.config[
                    'dads_reward_scale'] * dads_reward.numpy()
                # Replace the reward column in train_batch with the scaled DADS reward.
                train_batch['rewards'] = dads_reward
                stats['avg_dads_reward'] = dads_reward.mean()
                stats['num_skills_higher_prob'] = info['num_higher_prob']

        # Recompute advantages (GAE) per episode using the relabeled rewards.
        trajs = train_batch.split_by_episode()
        processed_trajs = []
        for traj in trajs:
            processed_trajs.append(compute_gae_for_sample_batch(self, traj))
        batch = SampleBatch.concat_samples(processed_trajs)

        # Update the agent with PPO over minibatches of the processed batch.
        c = 0
        for ep in range(self.ppo_epochs):
            for mb in minibatches(batch, self.minibatch_size):
                c += 1
                mb["is_training"] = True
                mb['advantages'] = standardize(mb['advantages'])
                minibatch = self._lazy_tensor_dict(mb)
                # compute the loss
                loss_out = ppo_surrogate_loss(self, self.model,
                                              self.dist_class, minibatch)
                # compute gradient
                self.ppo_opt.zero_grad()
                # the learning_rate is already used in ppo_surrogate_loss
                loss_out.backward()
                # grad norm
                if self.config['grad_clip']:
                    grad_norm = nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config['grad_clip'])
                self.ppo_opt.step()
                # log stats
                stats['ppo_loss'] += loss_out.item()
        stats['ppo_loss'] /= c
        # add more info about the loss
        stats.update(kl_and_loss_stats(self, train_batch))

        return {LEARNER_STATS_KEY: stats}
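
A brief usage note (the names below are placeholders, not from the example): RLlib's trainer hands each collected SampleBatch to learn_on_batch, so the dynamics update, the reward relabeling, and the PPO epochs all happen inside a single training step.

# Hedged sketch of the calling convention; `my_policy` and `collected_batch`
# are placeholders that do not appear in the example above.
#   results = my_policy.learn_on_batch(collected_batch)  # SampleBatch of rollouts
#   stats = results[LEARNER_STATS_KEY]
# With use_dynamics enabled, `stats` contains dynamics_loss, orth_loss, l2_loss,
# avg_dads_reward and num_skills_higher_prob, plus ppo_loss and the usual
# kl_and_loss_stats() entries.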