Example #1
def test_export_cpp(env, policy: Policy, tmpdir, file_type):
    # Generate scripted version (in double mode for CPP compatibility)
    scripted = policy.double().script()

    # Export
    export_file = osp.join(tmpdir, 'policy' + file_type)
    scripted.save(export_file)

    # Import again
    loaded = to.jit.load(export_file)

    # Compare a couple of inputs
    for i in range(50):
        obs = policy.env_spec.obs_space.sample_uniform()
        act_scripted = scripted(to.from_numpy(obs)).cpu().numpy()
        act_loaded = loaded(to.from_numpy(obs)).cpu().numpy()
        assert act_loaded == pytest.approx(act_scripted), f"Wrong action values on step #{i}"

    # Test after reset
    if hasattr(scripted, 'reset'):
        scripted.reset()
        loaded.reset()
        assert loaded.hidden.numpy() == pytest.approx(scripted.hidden.numpy()), "Wrong hidden state after reset"

        obs = policy.env_spec.obs_space.sample_uniform()
        act_scripted = scripted(to.from_numpy(obs)).numpy()
        act_loaded = loaded(to.from_numpy(obs)).numpy()
        assert act_loaded == pytest.approx(act_scripted), "Wrong action values after reset"
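
The round trip above (script, save, load, compare) can be reproduced without any Pyrado machinery. A minimal sketch, assuming plain PyTorch only; the TinyPolicy module and its dimensions are invented for illustration, whereas a Pyrado policy yields the equivalent scripted module via policy.double().script():

import os.path as osp
import tempfile

import torch as to


class TinyPolicy(to.nn.Module):
    """Stand-in for a Pyrado policy: a single linear layer mapping observations to actions."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.net = to.nn.Linear(obs_dim, act_dim)

    def forward(self, obs: to.Tensor) -> to.Tensor:
        return self.net(obs)


with tempfile.TemporaryDirectory() as tmpdir:
    # Script the module in double precision, mirroring policy.double().script()
    scripted = to.jit.script(TinyPolicy(obs_dim=4, act_dim=2).double())

    # Serialize and re-import the TorchScript module
    export_file = osp.join(tmpdir, "policy.pt")
    scripted.save(export_file)
    loaded = to.jit.load(export_file)

    # Both modules must produce the same action for the same observation
    obs = to.rand(4, dtype=to.double)
    assert to.allclose(scripted(obs), loaded(obs))
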
Example #2
def test_export_rcspysim(env: Env, policy: Policy, tmpdir: str):
    from rcsenv import ControlPolicy

    # Generate scripted version (double mode for CPP compatibility)
    scripted = policy.double().script()
    print(scripted.graph)

    # Export
    export_file = osp.join(tmpdir, "policy.pt")
    to.jit.save(scripted, export_file)

    # Import in C++
    cpp = ControlPolicy("torch", export_file)

    # Compare a couple of inputs
    for _ in range(50):
        obs = policy.env_spec.obs_space.sample_uniform()
        obs = to.from_numpy(obs).to(dtype=to.double)
        act_script = scripted(obs).cpu().numpy()
        act_cpp = cpp(obs, policy.env_spec.act_space.flat_dim)
        assert act_cpp == pytest.approx(act_script)

    # Test after reset
    if hasattr(scripted, "reset"):
        scripted.reset()
        cpp.reset()
        obs = policy.env_spec.obs_space.sample_uniform()
        obs = to.from_numpy(obs).to(dtype=to.double)
        act_script = scripted(obs).cpu().numpy()
        act_cpp = cpp(obs, policy.env_spec.act_space.flat_dim)
        assert act_cpp == pytest.approx(act_script)
Example #3
def test_script_recurrent(env: Env, policy: Policy):
    # Generate scripted version
    scripted = policy.double().script()

    # Compare results, tracing hidden manually
    hidden = policy.init_hidden()

    # Run one step
    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg, hidden = policy(obs, hidden)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)
    # Run second step
    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg, hidden = policy(obs, hidden)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)

    # Test after reset
    hidden = policy.init_hidden()
    scripted.reset()

    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg, hidden = policy(obs, hidden)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)
Example #4
def cpp_export(
    save_dir: pyrado.PathLike,
    policy: Policy,
    env: Optional[SimEnv] = None,
    policy_export_name: str = "policy_export",
    write_policy_node: bool = True,
    policy_node_name: str = "policy",
):
    """
    Convenience function to export the policy using PyTorch's scripting or tracing, and the experiment's XML
    configuration if the environment is from RcsPySim.

    :param save_dir: directory to save in
    :param policy: (trained) policy
    :param env: environment the policy was trained in
    :param policy_export_name: name of the exported policy file without the file type ending
    :param write_policy_node: if `True`, write the PyTorch-based control policy into the experiment's XML configuration.
                              This requires the experiment's XML configuration to be exported beforehand.
    :param policy_node_name: name of the control policy node in the XML file, e.g. 'policy' or 'preStrikePolicy'
    """
    if not osp.isdir(save_dir):
        raise pyrado.PathErr(given=save_dir)
    if not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if not isinstance(policy_export_name, str):
        raise pyrado.TypeErr(given=policy_export_name, expected_type=str)

    # Use torch.jit.trace / torch.jit.script (the latter if recurrent) to generate a torch.jit.ScriptModule
    ts_module = policy.double().script()  # can be evaluated like a regular PyTorch module

    # Serialize the script module to a file and save it in the same directory we loaded the policy from
    policy_export_file = osp.join(save_dir, f"{policy_export_name}.pt")
    ts_module.save(policy_export_file)  # former .zip, and before that .pth
    print_cbt(f"Exported the loaded policy to {policy_export_file}",
              "g",
              bright=True)

    # Export the experiment config for C++
    exp_export_file = osp.join(save_dir, "ex_config_export.xml")
    if env is not None and isinstance(inner_env(env), RcsSim):
        inner_env(env).save_config_xml(exp_export_file)
        print_cbt(f"Exported experiment configuration to {exp_export_file}",
                  "g",
                  bright=True)

    # Open the XML file again to add the policy node
    if write_policy_node and osp.isfile(exp_export_file):
        tree = et.parse(exp_export_file)
        root = tree.getroot()
        policy_node = et.Element(policy_node_name)
        policy_node.set("type", "torch")
        policy_node.set("file", f"{policy_export_name}.pt")
        root.append(policy_node)
        tree.write(exp_export_file)
        print_cbt(f"Added {policy_export_name}.pt to the experiment configuration.", "g")
Example #5
def test_script_nonrecurrent(env: Env, policy: Policy):
    # Generate scripted version
    scripted = policy.double().script()

    # Compare results
    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg = policy(obs)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)