Пример #1
0
def test_stepsequence_from_pandas(mock_data, given_rewards: bool):
    """Check that ``StepSequence.from_pandas`` reconstructs a rollout from a labeled DataFrame.

    Every state, observation, and action dimension is stored as its own column, named after the labels of the
    corresponding space and deliberately listed in shuffled order. Extra columns that do not belong to any space
    (``steps``, ``infos``) are added as well. The ``rewards`` column is only present if ``given_rewards`` is True.

    :param mock_data: fixture providing rewards, states, observations, actions, hidden states, and policy infos
    :param given_rewards: if True, store the rewards in the DataFrame and check that their values are reconstructed
    """
    rewards, states, observations, actions, _hidden, _policy_infos = mock_data
    states = np.asarray(states)
    observations = np.asarray(observations)
    actions = to.stack(actions).numpy()
    rewards = np.asarray(rewards)

    # Create fake observed data set. The labels must match the labels of the spaces. The order can be mixed.
    content = dict(
        s0=states[:, 0],
        s1=states[:, 1],
        s2=states[:, 2],
        o3=observations[:, 3],
        o0=observations[:, 0],
        o2=observations[:, 2],
        o1=observations[:, 1],
        a1=actions[:, 1],
        a0=actions[:, 0],
        # Some content that does not belong to any of the spaces
        steps=np.arange(0, states.shape[0]),
        # One fresh dict per row; `[dict(...)] * 6` would make every row alias the same dict object.
        # NOTE(review): the length 6 is hard-coded; presumably it should follow the rollout length of mock_data - confirm
        infos=[dict(foo="bar") for _ in range(6)],
    )
    if given_rewards:
        content["rewards"] = rewards
    # Wrapping every column in a pd.Series lets columns of different lengths coexist (shorter ones are NaN-padded)
    df = pd.DataFrame({k: pd.Series(v) for k, v in content.items()})

    env = MockEnv(
        state_space=InfBoxSpace(shape=states[0].shape,
                                labels=["s0", "s1", "s2"]),
        obs_space=InfBoxSpace(shape=observations[0].shape,
                              labels=["o0", "o1", "o2", "o3"]),
        act_space=InfBoxSpace(shape=actions[0].shape, labels=["a0", "a1"]),
    )

    reconstructed = StepSequence.from_pandas(df, env.spec)

    assert len(reconstructed.rewards) == len(rewards)
    if given_rewards:
        # Rewards that were provided must come back unchanged, not just with the right length
        assert np.allclose(reconstructed.rewards, rewards)
    assert np.allclose(reconstructed.states, states)
    assert np.allclose(reconstructed.observations, observations)
    assert np.allclose(reconstructed.actions, actions)


if __name__ == "__main__":
    # Convert a rollout stored in a CSV file into a StepSequence and save it as a pickle file
    # Parse command line arguments
    args = get_argparser().parse_args()
    if not osp.isfile(args.file):
        raise pyrado.PathErr(given=args.file)
    if args.dir is None:
        # Use the file's directory by default. Resolve to an absolute path first, because osp.dirname() of a
        # bare file name (e.g. "rollout.csv") is the empty string.
        args.dir = osp.dirname(osp.abspath(args.file))
    elif not osp.isdir(args.dir):
        raise pyrado.PathErr(given=args.dir)

    df = pd.read_csv(args.file)

    if args.env_name == MiniGolfIKSim.name:
        env = MiniGolfIKSim()
    elif args.env_name == MiniGolfJointCtrlSim.name:
        env = MiniGolfJointCtrlSim()
    else:
        raise NotImplementedError(f"Converting rollouts is not implemented for the environment {args.env_name}!")

    # Cast the rollout from a DataFrame to a StepSequence
    reconstructed = StepSequence.from_pandas(df, env.spec, task=env.task)

    # Name the output after the input file, e.g. "path/to/foo.csv" -> "rollout_foo.pkl". Using os.path instead of
    # slicing after the last "/" works with any path separator and any extension length.
    # args.dir is always set at this point (either given and validated, or derived from the file's location).
    suffix = osp.splitext(osp.basename(args.file))[0]
    pyrado.save(reconstructed,
                f"rollout_{suffix}.pkl",
                args.dir,
                verbose=True)