def em(self, obs, act=None, nb_iter=50, tol=1e-4, initialize=True,
       ctl_mstep_kwargs=None, **kwargs):
    """Fit one model per training split with EM and score each model.

    The trajectories in ``obs``/``act`` are split with ``train_test_split``
    into ``self.ensemble_size`` training subsets (whole trajectories,
    ``split_trajs=False``). ``self._parallel_em`` fits one model per subset,
    and the fitted models are stored in ``self.models``.

    Args:
        obs: observation trajectories; each element is passed through
            ``np.vstack`` to count its time steps.
        act: action trajectories aligned with ``obs`` (may be ``None``).
        nb_iter: forwarded to ``_parallel_em``.
        tol: forwarded to ``_parallel_em``.
        initialize: forwarded to ``_parallel_em``.
        ctl_mstep_kwargs: keyword arguments forwarded to ``_parallel_em``.
            ``None`` (the default) means an empty dict.
        **kwargs: accepted for interface compatibility; currently ignored.

    Returns:
        ``(train_scores, test_scores)``, two 1-D arrays with one entry per
        fitted model.

        - ``train_scores``: the model's ``log_normalizer`` on its own training
          split, divided by the number of training time steps.
        - ``test_scores``: ``log_normalizer`` on the full data minus
          ``log_normalizer`` on the training split, divided by the number of
          held-out time steps. This is presumably a per-step held-out
          log-likelihood, assuming ``log_normalizer`` is additive over
          trajectories (TODO confirm).

        If a split contains every time step, the held-out denominator is 0.
        NumPy then yields ``inf``/``nan`` with a ``RuntimeWarning``.
    """
    from sds.utils.general import train_test_split

    # Use a fresh dict on every call. A shared ``{}`` default would carry
    # state between calls if anything downstream mutated it.
    if ctl_mstep_kwargs is None:
        ctl_mstep_kwargs = {}

    # One training subset per ensemble member. The held-out parts returned
    # by ``train_test_split`` are not needed here.
    train_obs, train_act = train_test_split(
        obs, act, nb_traj_splits=self.ensemble_size, split_trajs=False)[:2]

    # The second return value of ``_parallel_em`` is not used here.
    self.models, _ = self._parallel_em(train_obs, train_act,
                                       nb_iter=nb_iter, tol=tol,
                                       initialize=initialize,
                                       ctl_mstep_kwargs=ctl_mstep_kwargs)

    # Time steps in each training subset, and in the full data set.
    nb_train = np.hstack([np.vstack(x).shape[0] for x in train_obs])
    nb_total = np.vstack(obs).shape[0]

    # Evaluate each model on its own training split and on all the data.
    # Both calls are kept together in a single loop to preserve call order.
    train_ll, total_ll = [], []
    for x, u, m in zip(train_obs, train_act, self.models):
        train_ll.append(m.log_normalizer(x, u))
        total_ll.append(m.log_normalizer(obs, act))
    train_ll = np.hstack(train_ll)
    total_ll = np.hstack(total_ll)

    train_scores = train_ll / nb_train
    # Held-out part = (all data) - (training split), normalised per step.
    test_scores = (total_ll - train_ll) / (nb_total - nb_train)
    return train_scores, test_scores
# Configure the underlying (unwrapped) environment before sampling.
# NOTE(review): `sigma` and `uniform` are defined by the env class, which is
# not visible here. Presumably they are a noise scale and a uniform
# initial-state flag; confirm.
env.unwrapped.sigma = 1e-4
env.unwrapped.uniform = True
# Seed the environment for reproducible rollouts. Any extra randomness inside
# sample_env (e.g. the noise below) may need its own seed; confirm.
env.seed(1337)

# Load the pretrained SAC policy from ./sac_pendulum.
from stable_baselines import SAC
_ctl = SAC.load("./sac_pendulum")
# Wrap it as a plain callable: observation -> action.
# `predict` returns a tuple; element [0] is the action.
sac_ctl = lambda x: _ctl.predict(x)[0]

# Collect nb_rollouts trajectories of nb_steps steps each with the SAC
# controller. noise_std is forwarded to sample_env (presumably additive
# action noise; confirm).
nb_rollouts, nb_steps = 50, 200
obs, act = sample_env(env, nb_rollouts, nb_steps, sac_ctl, noise_std=1e-2)

# Split whole trajectories (split_trajs=False) with a fixed seed.
# Only the training parts are kept; the last two return values are discarded.
from sds.utils.general import train_test_split
train_obs, train_act, _, _ = train_test_split(obs, act, seed=3, nb_traj_splits=6, split_trajs=False)

# Plot every rollout in three side-by-side panels:
# first observation component, last observation component, and the action.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
fig.suptitle('Pendulum SAC Demonstrations')
for _obs, _act in zip(obs, act):
    # Alternative (disabled): plot the angle instead of obs[:, 0].
    # This assumes obs[:, 0] and obs[:, 1] are cos/sin of the angle; confirm.
    # angle = np.arctan2(_obs[:, 1], _obs[:, 0])
    # axs[0].plot(angle)
    axs[0].plot(_obs[:, 0])
    # beautify is re-applied on every iteration, and its return value
    # replaces the axis entry.
    axs[0] = beautify(axs[0])
    axs[1].plot(_obs[:, -1])
    axs[1] = beautify(axs[1])
    axs[2].plot(_act)
    axs[2] = beautify(axs[2])