def get_loss(self):
     """Return the negative log-probability of the stored actions.

     ``self.features`` and ``self.next_features`` are concatenated along
     axis 2, flattened with ``flatten_dims(..., 1)`` and passed through
     ``self.fc``.  The result is used as the flat parameters of the policy's
     action distribution type, and the returned value is that distribution's
     ``neglogp`` of ``self.ac``.

     NOTE(review): this looks like an inverse-dynamics style loss (predict
     the action from current and next features) -- confirm against the
     enclosing class.  ``flatten_dims`` is defined outside this block.
     """
     joined = torch.cat([self.features, self.next_features], 2)
     pd_params = self.fc(flatten_dims(joined, 1))
     action_dist = self.policy.ac_pdtype.pdfromflat(pd_params)
     # Flatten the stored actions the same way the parameters were flattened;
     # the number of kept trailing axes is the rank of a single action.
     flat_actions = flatten_dims(self.ac, len(self.ac_space.shape))
     return action_dist.neglogp(torch.tensor(flat_actions))
Esempio n. 2
0
    def get_loss(self):
        """Return the mean squared error between predicted and next features.

        The stored actions are one-hot encoded over ``self.ac_space.n``
        classes, fed together with ``self.features`` to ``self.loss_net`` and
        the output is compared with ``self.next_features``.  The squared
        difference is averaged over the last axis only, so the result keeps
        the leading axes of ``self.features``.

        NOTE(review): the one-hot step scatters along axis 1, so it assumes
        the flattened actions are 1-D integer indices -- confirm with the
        action buffer's dtype/shape.  ``flatten_dims`` and
        ``unflatten_first_dim`` are defined outside this block.
        """
        raw_actions = self.ac
        action_shape = raw_actions.shape
        flat_actions = flatten_dims(raw_actions, len(self.ac_space.shape))

        # One-hot encode: start from zeros and write a 1 at each action index.
        one_hot = torch.zeros(flat_actions.shape + (self.ac_space.n, ))
        one_hot = one_hot.scatter_(1,
                                   torch.tensor(flat_actions).unsqueeze(1),
                                   1)
        one_hot = unflatten_first_dim(one_hot, action_shape)

        features = self.features
        next_features = self.next_features
        # Features and actions must agree on every axis except the last.
        assert features.shape[:-1] == one_hot.shape[:-1]

        feature_shape = features.shape
        predicted = self.loss_net(flatten_dims(features, 1),
                                  flatten_dims(one_hot, 1))
        predicted = unflatten_first_dim(predicted, feature_shape)
        squared_error = (predicted - next_features)**2
        return torch.mean(squared_error, -1)
 def get_features(self, x):
     """Normalise observations, reorder their axes and run the feature model.

     A 5-D input is treated as carrying an extra leading axis: it is
     flattened with ``flatten_dims`` before processing and the model output
     is restored with ``unflatten_first_dim``.  Any other rank is processed
     as is.

     The input is normalised with ``self.ob_mean`` / ``self.ob_std`` first,
     then its last axis is moved in front of the two axes preceding it
     (e.g. ``[..., H, W, C] -> [..., C, H, W]``), converted to a tensor and
     passed to ``self.features_model``.
     """
     has_time_axis = len(x.shape) == 5
     if has_time_axis:
         full_shape = x.shape
         x = flatten_dims(x, len(self.ob_space.shape))

     normalized = (x - self.ob_mean) / self.ob_std

     # Leading axes stay where they are; only the last three are permuted.
     leading_axes = list(range(len(normalized.shape) - 3))
     channels_first = np.transpose(normalized, leading_axes + [-1, -3, -2])

     features = self.features_model(torch.tensor(channels_first))
     if not has_time_axis:
         return features
     return unflatten_first_dim(features, full_shape)
Esempio n. 4
0
 def get_features(self, x):
     """Reorder observation axes, normalise and run the feature model.

     A 5-D input is treated as carrying an extra leading axis: it is
     flattened with ``flatten_dims`` (keeping as many trailing axes as a
     single observation has) before processing, and the model output is
     restored with ``unflatten_first_dim``.

     The rank check uses ``len(x.shape)`` so that it works for any
     array-like with a ``shape`` attribute; the previous
     ``x.get_shape().ndims`` only exists on TensorFlow tensors and raised
     ``AttributeError`` for the arrays that ``np.transpose`` below expects.
     """
     x_has_timesteps = (len(x.shape) == 5)
     if x_has_timesteps:
         sh = x.shape
         # Keep the trailing observation axes; the number to keep is the rank
         # of one observation, not a discrete-space size (``ob_space.n``).
         x = flatten_dims(x, len(self.ob_space.shape))
     # Move the last axis in front of the two preceding it
     # (e.g. [..., H, W, C] -> [..., C, H, W]); leading axes are untouched.
     x = np.transpose(x, list(range(len(x.shape) - 3)) + [-1, -3, -2])
     # NOTE(review): normalisation happens AFTER the transpose here.  If
     # ``ob_mean`` is an array laid out like the un-transposed observation,
     # this broadcasts against the wrong axes -- confirm the layout.
     x = (x - self.ob_mean) / self.ob_std
     # NOTE(review): ``x`` is passed on without conversion to a tensor --
     # confirm ``features_model`` accepts this input type.
     x = self.features_model(x)
     if x_has_timesteps:
         x = unflatten_first_dim(x, sh)
     return x
Esempio n. 5
0
 def update_features(self, ob, ac):
     """Recompute policy outputs for ``ob`` and cache them on ``self``.

     The observations are flattened with ``flatten_dims`` (keeping as many
     trailing axes as one observation has), turned into features with
     ``self.get_features`` and passed through ``self.pd_hidden``.  The hidden
     activations feed both ``self.pd_head`` (distribution parameters) and
     ``self.vf_head`` (value prediction).

     Attributes written: ``flat_features``, ``vpred`` (restored to the
     leading axes of ``ob`` via ``unflatten_first_dim``), ``pd``, ``ac`` and
     ``ob``.  Nothing is returned.

     NOTE(review): per the original author's comment ``ob`` is expected to be
     shaped ``[nenvs, timestep, H, W, C]`` -- confirm with the caller.
     """
     ob_shape = ob.shape
     # Merge the leading axes so the model sees one observation per row.
     flat_ob = flatten_dims(ob, len(self.ob_space.shape))

     self.flat_features = self.get_features(flat_ob)
     hidden = self.pd_hidden(self.flat_features)
     pd_params = self.pd_head(hidden)
     flat_vpred = self.vf_head(hidden)

     self.vpred = unflatten_first_dim(flat_vpred, ob_shape)
     self.pd = self.ac_pdtype.pdfromflat(pd_params)
     self.ac = ac
     self.ob = ob
 def decoder(self, z):
     """Decode ``z`` into a Normal distribution.

     A 3-D ``z`` is flattened with ``flatten_dims(z, 1)`` before being passed
     to ``self.decoder_model`` and the output is restored with
     ``unflatten_first_dim``; any other rank goes to the model unchanged.

     When ``self.spherical_obs`` is true the model output is the mean and the
     standard deviation is a single shared value: ``softplus`` of
     ``self.scale`` floored at -4.0, broadcast to the mean's shape.
     Otherwise the model output is split into two groups of 4 along axis -3;
     the first is the mean and ``softplus`` of the second is the standard
     deviation.

     Returns:
         torch.distributions.normal.Normal
     """
     has_time_axis = len(z.shape) == 3
     if has_time_axis:
         input_shape = z.shape
         decoded = self.decoder_model(flatten_dims(z, 1))
         decoded = unflatten_first_dim(decoded, input_shape)
     else:
         decoded = self.decoder_model(z)

     softplus = torch.nn.functional.softplus
     if not self.spherical_obs:
         # First 4 slices along axis -3 are the mean, the remaining 4 the
         # pre-activation scale.
         mean, raw_scale = torch.split(decoded, [4, 4], -3)
         return torch.distributions.normal.Normal(mean, softplus(raw_scale))

     # Shared scale: floor the parameter before softplus, then broadcast it.
     floored = torch.max(self.scale, torch.tensor(-4.0))
     std = softplus(floored) * torch.ones(decoded.shape)
     return torch.distributions.normal.Normal(decoded, std)
Esempio n. 7
0
    def update(self):
        """Run one optimisation phase over the current rollout.

        Steps, in order:

        1. Optionally rescale the rollout rewards by the running standard
           deviation of the discounted reward filter (``self.normrew``).
        2. Compute advantages and returns via ``self.calculate_advantages``
           and optionally standardise the advantages (``self.normadv``).
        3. For ``self.nepochs`` epochs, shuffle the rollout rows and take one
           optimiser step per minibatch on the sum of: clipped surrogate
           policy loss, entropy bonus, value loss, auxiliary-task loss and
           dynamics loss.
        4. Collect diagnostics and timing information.

        Returns:
            dict: rollout/advantage statistics, the per-minibatch losses
            averaged over every optimiser step actually taken, rollout stat
            lists reduced to their mean, and timing values (``ups``,
            ``tps``, ``total_secs``).

        Side effects: updates ``self.rff``/``self.rff_rms`` (when rewards are
        normalised), ``self.buf_advs``, the optimiser and model parameters,
        ``self.n_updates`` and ``self.t_last_update``.
        """
        if self.normrew:
            rffs = np.array(
                [self.rff.update(rew) for rew in self.rollout.buf_rews.T])
            rffs_mean, rffs_std, rffs_count = mpi_moments(rffs.ravel())
            self.rff_rms.update_from_moments(rffs_mean, rffs_std**2,
                                             rffs_count)
            rews = self.rollout.buf_rews / np.sqrt(self.rff_rms.var)
        else:
            rews = np.copy(self.rollout.buf_rews)
        self.calculate_advantages(rews=rews,
                                  use_news=self.use_news,
                                  gamma=self.gamma,
                                  lam=self.lam)

        info = dict(advmean=self.buf_advs.mean(),
                    advstd=self.buf_advs.std(),
                    retmean=self.buf_rets.mean(),
                    retstd=self.buf_rets.std(),
                    vpredmean=self.rollout.buf_vpreds.mean(),
                    vpredstd=self.rollout.buf_vpreds.std(),
                    ev=explained_variance(self.rollout.buf_vpreds.ravel(),
                                          self.buf_rets.ravel()),
                    rew_mean=np.mean(self.rollout.buf_rews),
                    recent_best_ext_ret=self.rollout.current_max)
        if self.rollout.best_ext_ret is not None:
            info['best_ext_ret'] = self.rollout.best_ext_ret

        to_report = {
            'total': 0.0,
            'pg': 0.0,
            'vf': 0.0,
            'ent': 0.0,
            'approxkl': 0.0,
            'clipfrac': 0.0,
            'aux': 0.0,
            'dyn_loss': 0.0,
            'feat_var': 0.0
        }

        # normalize advantages
        if self.normadv:
            m, s = get_mean_and_std(self.buf_advs)
            self.buf_advs = (self.buf_advs - m) / (s + 1e-7)

        n_rows = self.nenvs * self.nsegs_per_env
        envsperbatch = max(1, n_rows // self.nminibatches)
        envinds = np.arange(n_rows)

        # Number of optimiser steps that will actually run.  This equals
        # nminibatches * nepochs only when n_rows divides evenly; using the
        # real count keeps the reported values true averages otherwise.
        total_batches = self.nepochs * len(range(0, n_rows, envsperbatch))

        for _ in range(self.nepochs):
            np.random.shuffle(envinds)
            for start in range(0, n_rows, envsperbatch):
                mbenvinds = envinds[start:start + envsperbatch]

                acs = self.rollout.buf_acs[mbenvinds]
                nlps = self.rollout.buf_nlps[mbenvinds]
                obs = self.rollout.buf_obs[mbenvinds]
                rets = self.buf_rets[mbenvinds]
                advs = self.buf_advs[mbenvinds]
                last_obs = self.rollout.buf_obs_last[mbenvinds]

                cliprange = self.cliprange

                # Refresh the cached forward passes for this minibatch; the
                # losses below read the state these calls leave behind.
                self.stochpol.update_features(obs, acs)
                self.dynamics.auxiliary_task.update_features(obs, last_obs)
                self.dynamics.update_features(obs, last_obs)

                feat_loss = torch.mean(self.dynamics.auxiliary_task.get_loss())
                dyn_loss = torch.mean(self.dynamics.get_loss())

                acs = torch.tensor(flatten_dims(acs, len(self.ac_space.shape)))
                # Squeeze once and reuse, so the ratio and the KL estimate
                # are computed from identically shaped tensors.
                neglogpac = self.stochpol.pd.neglogp(acs).squeeze()
                entropy = torch.mean(self.stochpol.pd.entropy())
                vpred = self.stochpol.vpred
                vf_loss = 0.5 * torch.mean(
                    (vpred.squeeze() - torch.tensor(rets))**2)

                nlps = torch.tensor(flatten_dims(nlps, 0))
                # new_prob / old_prob, expressed through negative log-probs.
                ratio = torch.exp(nlps - neglogpac)

                advs = flatten_dims(advs, 0)
                negadv = torch.tensor(-advs)
                pg_losses1 = negadv * ratio
                pg_losses2 = negadv * torch.clamp(
                    ratio, min=1.0 - cliprange, max=1.0 + cliprange)
                # Pessimistic (element-wise max) of unclipped and clipped
                # surrogate losses.
                pg_loss_surr = torch.max(pg_losses1, pg_losses2)
                pg_loss = torch.mean(pg_loss_surr)
                ent_loss = (-self.ent_coef) * entropy

                approxkl = 0.5 * torch.mean((neglogpac - nlps)**2)
                clipfrac = torch.mean(
                    (torch.abs(pg_losses2 - pg_loss_surr) > 1e-6).float())
                feat_var = torch.std(self.dynamics.auxiliary_task.features)

                total_loss = pg_loss + ent_loss + vf_loss + feat_loss + dyn_loss

                total_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                batch_stats = {
                    'total': total_loss,
                    'pg': pg_loss,
                    'vf': vf_loss,
                    'ent': ent_loss,
                    'approxkl': approxkl,
                    'clipfrac': clipfrac,
                    'feat_var': feat_var,
                    'aux': feat_loss,
                    'dyn_loss': dyn_loss,
                }
                # Every entry is a 0-dim tensor; accumulate its running mean.
                for key, value in batch_stats.items():
                    to_report[key] += value.item() / total_batches

        info.update(to_report)
        self.n_updates += 1
        info["n_updates"] = self.n_updates
        info.update({
            dn: (np.mean(dvs) if len(dvs) > 0 else 0)
            for (dn, dvs) in self.rollout.statlists.items()
        })
        info.update(self.rollout.stats)
        info.pop("states_visited", None)
        tnow = time.time()
        info["ups"] = 1. / (tnow - self.t_last_update)
        info["total_secs"] = tnow - self.t_start
        info['tps'] = self.rollout.nsteps * self.nenvs / (
            tnow - self.t_last_update)  # MPI.COMM_WORLD.Get_size() *
        self.t_last_update = tnow

        return info