def test_stepsequence_from_pandas(mock_data, given_rewards: bool):
    """Check that `StepSequence.from_pandas` rebuilds a rollout from a DataFrame.

    The DataFrame columns are labelled with the labels of the mock environment's state, observation, and action
    spaces. The columns are deliberately given in a shuffled order. Extra columns that belong to no space
    (`steps`, `infos`) are mixed in. A `rewards` column is added only if `given_rewards` is True.
    """
    rewards, states, observations, actions, _hidden, _policy_infos = mock_data
    state_arr = np.asarray(states)
    obs_arr = np.asarray(observations)
    act_arr = to.stack(actions).numpy()
    rew_arr = np.asarray(rewards)

    # Fake recorded data. Labels must match the space labels, but the column order may be arbitrary.
    columns = {
        "s0": state_arr[:, 0],
        "s1": state_arr[:, 1],
        "s2": state_arr[:, 2],
        "o3": obs_arr[:, 3],
        "o0": obs_arr[:, 0],
        "o2": obs_arr[:, 2],
        "o1": obs_arr[:, 1],
        "a1": act_arr[:, 1],
        "a0": act_arr[:, 0],
        # Extra columns that are not part of any of the spaces
        "steps": np.arange(len(state_arr)),
        # NOTE(review): the length 6 is hard-coded; presumably it matches the fixture's number of steps.
        "infos": [dict(foo="bar")] * 6,
    }
    if given_rewards:
        columns["rewards"] = rew_arr

    # Wrapping every column in a Series lets columns of different lengths be aligned on their index
    # (shorter ones are padded with NaN) instead of making the DataFrame constructor raise.
    df = pd.DataFrame({label: pd.Series(values) for label, values in columns.items()})

    env = MockEnv(
        state_space=InfBoxSpace(shape=state_arr[0].shape, labels=["s0", "s1", "s2"]),
        obs_space=InfBoxSpace(shape=obs_arr[0].shape, labels=["o0", "o1", "o2", "o3"]),
        act_space=InfBoxSpace(shape=act_arr[0].shape, labels=["a0", "a1"]),
    )
    ro = StepSequence.from_pandas(df, env.spec)

    assert len(ro.rewards) == len(rew_arr)
    assert np.allclose(ro.states, state_arr)
    assert np.allclose(ro.observations, obs_arr)
    assert np.allclose(ro.actions, act_arr)
if __name__ == "__main__":
    # Parse command line arguments
    args = get_argparser().parse_args()

    # Validate the input file and the output directory
    if not osp.isfile(args.file):
        raise pyrado.PathErr(given=args.file)
    if args.dir is None:
        # Use the file's directory by default. Resolve to an absolute path first: for a bare file name,
        # osp.dirname() returns "", which the osp.isdir() check below would reject as a directory.
        args.dir = osp.dirname(osp.abspath(args.file))
    elif not osp.isdir(args.dir):
        raise pyrado.PathErr(given=args.dir)

    df = pd.read_csv(args.file)

    # Select the environment whose spec (and task) is used to interpret the DataFrame columns
    if args.env_name == MiniGolfIKSim.name:
        env = MiniGolfIKSim()
    elif args.env_name == MiniGolfJointCtrlSim.name:
        env = MiniGolfJointCtrlSim()
    else:
        raise NotImplementedError(f"Unsupported environment name: {args.env_name}")

    # Cast the rollout from a DataFrame to a StepSequence
    reconstructed = StepSequence.from_pandas(df, env.spec, task=env.task)

    # Name the output after the input file without its directory and extension,
    # e.g. "data/run_01.csv" -> "rollout_run_01.pkl". args.dir is always set at this point.
    suffix = osp.splitext(osp.basename(args.file))[0]
    pyrado.save(reconstructed, f"rollout_{suffix}.pkl", args.dir, verbose=True)