Example #1
def test_script_recurrent(env: Env, policy: Policy):
    # Generate scripted version
    scripted = policy.double().script()

    # Compare results, tracing hidden manually
    hidden = policy.init_hidden()

    # Run one step
    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg, hidden = policy(obs, hidden)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)
    # Run second step
    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg, hidden = policy(obs, hidden)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)

    # Test after reset
    hidden = policy.init_hidden()
    scripted.reset()

    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg, hidden = policy(obs, hidden)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)
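
For reference, the equivalence being tested rests on TorchScript: a scripted module must reproduce the eager module's outputs. A minimal standalone sketch of this check on a plain torch module (independent of pyrado's Policy API):

import torch as to
import torch.nn as nn

# Compile an eager module to a ScriptModule and verify both agree
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2)).double()
scripted = to.jit.script(net)

obs = to.rand(4, dtype=to.double)
to.testing.assert_allclose(net(obs), scripted(obs))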
Example #2
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    assert policy.is_recurrent
    obs = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden()
    # Use a random one to ensure we don't just run into the 0-special-case
    hidden = to.rand_like(hidden)
    assert len(hidden) == policy.hidden_size

    # Test general conformity
    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act, otherhead, hid_new = policy(obs, hidden)
        assert len(hid_new) == policy.hidden_size
    else:
        act, hid_new = policy(obs, hidden)
        assert len(hid_new) == policy.hidden_size

    # Test reproducibility
    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act2, otherhead2, hid_new2 = policy(obs, hidden)
        to.testing.assert_allclose(act, act2)
        to.testing.assert_allclose(otherhead, otherhead2)
        to.testing.assert_allclose(hid_new, hid_new2)
    else:
        act2, hid_new2 = policy(obs, hidden)
        to.testing.assert_allclose(act, act2)
        to.testing.assert_allclose(hid_new, hid_new2)
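
The reproducibility check above only holds in evaluation mode. A self-contained sketch of why policy.eval() matters, using torch's dropout (the assumption being that the tested policies may contain dropout or similar stochastic layers):

import torch as to
import torch.nn as nn

layer = nn.Sequential(nn.Linear(8, 64), nn.Dropout(p=0.5))
x = to.rand(8)

layer.train()
assert not to.equal(layer(x), layer(x))  # dropout masks differ per forward pass

layer.eval()
to.testing.assert_allclose(layer(x), layer(x))  # dropout is a no-op in eval mode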
Example #3
def test_recurrent_policy_batching(env: Env, policy: Policy, batch_size: int):
    assert policy.is_recurrent
    obs = np.stack([
        policy.env_spec.obs_space.sample_uniform() for _ in range(batch_size)
    ])  # shape = (batch_size, 4)
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden(batch_size)
    # Use a random one to ensure we don't just run into the 0-special-case
    hidden.random_()
    assert hidden.shape == (batch_size, policy.hidden_size)

    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act, _, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert hid_new.shape == (batch_size, policy.hidden_size)

    if batch_size > 1:
        # Try to use a subset of the batch
        subset = to.arange(batch_size // 2)
        if isinstance(policy, TwoHeadedRNNPolicyBase):
            act_sub, _, hid_sub = policy(obs[subset, :], hidden[subset, :])
        else:
            act_sub, hid_sub = policy(obs[subset, :], hidden[subset, :])
        to.testing.assert_allclose(act_sub, act[subset, :])
        to.testing.assert_allclose(hid_sub, hid_new[subset, :])
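
The subset check works because recurrent layers process batch rows independently. A minimal sketch of the same property with a bare torch RNN cell (not pyrado's policy classes):

import torch as to
import torch.nn as nn

cell = nn.RNNCell(input_size=4, hidden_size=8).eval()

obs = to.rand(6, 4)     # batch of 6 observations
hidden = to.rand(6, 8)  # matching batch of hidden states
hid_new = cell(obs, hidden)

# Each batch row is processed independently, so any subset must match
subset = to.arange(3)
to.testing.assert_allclose(cell(obs[subset], hidden[subset]), hid_new[subset])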
Example #4
def evaluate(policy: Policy,
             inps: to.Tensor,
             targs: to.Tensor,
             windowed: bool,
             cascaded: bool,
             num_init_samples: int,
             hidden: Optional[to.Tensor] = None,
             loss_fcn=nn.MSELoss(),
             verbose: bool = True):
    if inps.shape[0] != targs.shape[0]:
        raise pyrado.ShapeErr(given=inps, expected_match=targs)

    # Set the policy, i.e. the PyTorch nn.Module, to evaluation mode
    policy.eval()

    targs = targs[num_init_samples:, :] if num_init_samples > 0 else targs
    preds = to.empty_like(targs)

    # Pass the first samples through the network in order to initialize the hidden state
    inp = inps[:num_init_samples, :] if num_init_samples > 0 else inps[0].unsqueeze(0)  # running input
    pred, hidden = TSPred.predict(policy, inp, windowed, cascaded=False, hidden=hidden)

    # Run the remaining steps consecutively, reusing the hidden state
    for idx in range(inps.shape[0] - num_init_samples):
        if not cascaded or idx == 0:
            # Forget the oldest input and append the latest input
            inp = inps[idx + num_init_samples, :].unsqueeze(0)
        else:
            # Forget the oldest input and append the latest prediction
            inp = pred

        pred, hidden = TSPred.predict(policy, inp, windowed, cascaded=False, hidden=hidden)
        preds[idx, :] = pred

    # Compute the loss for the entire data set at once
    loss = loss_fcn(targs, preds)

    if verbose:
        print_cbt(
            f'The {policy.name} policy with {policy.num_param} parameters predicted {inps.shape[0]} data points '
            f'with a loss of {loss.item():.4e}.', 'g')

    # Set the policy, i.e. the PyTorch nn.Module, back to training mode
    policy.train()

    return preds, loss
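
The cascaded branch above feeds the model's own prediction back as the next input instead of the ground truth. A self-contained toy version of that loop with a bare torch RNN cell on a synthetic sine series (the names cell and head are illustrative, not pyrado API):

import torch as to
import torch.nn as nn

cell = nn.RNNCell(input_size=1, hidden_size=16)
head = nn.Linear(16, 1)

inps = to.sin(to.linspace(0, 6.28, 50)).unsqueeze(-1)  # toy series, shape (50, 1)
hidden = to.zeros(16)
preds = []
inp = inps[0]
for t in range(49):
    hidden = cell(inp.unsqueeze(0), hidden.unsqueeze(0)).squeeze(0)
    pred = head(hidden)
    preds.append(pred)
    inp = pred  # cascaded: reuse the latest prediction as the next input

loss = nn.MSELoss()(to.stack(preds), inps[1:])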
Example #5
def test_actor_critic(ex_dir, env: SimEnv, policy: Policy, algo, algo_hparam, vfcn_type, use_cuda):
    pyrado.set_seed(0)

    if use_cuda:
        policy._device = "cuda"
        policy = policy.to(device="cuda")

    # Create value function
    if vfcn_type == "fnn-plain":
        vfcn = FNN(
            input_size=env.obs_space.flat_dim,
            output_size=1,
            hidden_sizes=[16, 16],
            hidden_nonlin=to.tanh,
            use_cuda=use_cuda,
        )
    elif vfcn_type == FNNPolicy.name:
        vf_spec = EnvSpec(env.obs_space, ValueFunctionSpace)
        vfcn = FNNPolicy(vf_spec,
                         hidden_sizes=[16, 16],
                         hidden_nonlin=to.tanh,
                         use_cuda=use_cuda)
    elif vfcn_type == RNNPolicy.name:
        vf_spec = EnvSpec(env.obs_space, ValueFunctionSpace)
        vfcn = RNNPolicy(vf_spec,
                         hidden_size=16,
                         num_recurrent_layers=1,
                         use_cuda=use_cuda)
    else:
        raise NotImplementedError

    # Create critic
    critic_hparam = dict(
        gamma=0.98,
        lamda=0.95,
        batch_size=32,
        lr=1e-3,
        standardize_adv=False,
    )
    critic = GAE(vfcn, **critic_hparam)

    # Common hyper-parameters
    common_hparam = dict(max_iter=2, min_rollouts=3, num_workers=1)
    # Add specific hyper-parameters, if any
    common_hparam.update(algo_hparam)

    # Create algorithm and train
    algo = algo(ex_dir, env, policy, critic, **common_hparam)
    algo.train()
    assert algo.curr_iter == algo.max_iter
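
The GAE critic above wraps generalized advantage estimation with discount gamma and trace decay lamda. For background, a minimal textbook sketch of the recursion A_t = delta_t + gamma * lamda * A_{t+1} (the standard formula, not pyrado's implementation):

import torch as to

def gae_advantages(rewards: to.Tensor, values: to.Tensor, gamma: float = 0.98, lamda: float = 0.95) -> to.Tensor:
    # values holds one extra bootstrap entry, i.e. len(values) == len(rewards) + 1
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = to.zeros_like(rewards)
    adv = to.tensor(0.0)
    for t in reversed(range(len(rewards))):
        adv = deltas[t] + gamma * lamda * adv
        advantages[t] = adv
    return advantages

advantages = gae_advantages(rewards=to.rand(10), values=to.rand(11))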
Example #6
def test_export_cpp(env, policy: Policy, tmpdir, file_type):
    # Generate scripted version (in double mode for CPP compatibility)
    scripted = policy.double().script()

    # Export
    export_file = osp.join(tmpdir, 'policy' + file_type)
    scripted.save(export_file)

    # Import again
    loaded = to.jit.load(export_file)

    # Compare a couple of inputs
    for i in range(50):
        obs = policy.env_spec.obs_space.sample_uniform()
        act_scripted = scripted(to.from_numpy(obs)).cpu().numpy()
        act_loaded = loaded(to.from_numpy(obs)).cpu().numpy()
        assert act_loaded == pytest.approx(act_scripted), f"Wrong action values on step #{i}"

    # Test after reset
    if hasattr(scripted, 'reset'):
        scripted.reset()
        loaded.reset()
        assert loaded.hidden.numpy() == pytest.approx(scripted.hidden.numpy()), "Wrong hidden state after reset"

        obs = policy.env_spec.obs_space.sample_uniform()
        act_scripted = scripted(to.from_numpy(obs)).numpy()
        act_loaded = loaded(to.from_numpy(obs)).numpy()
        assert act_loaded == pytest.approx(act_scripted), "Wrong action values after reset"
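
At its core, the export test is a TorchScript save/load round trip. A standalone sketch of that round trip (the temporary path is illustrative):

import tempfile
import os.path as osp
import torch as to
import torch.nn as nn

net = nn.Linear(4, 2).double()
scripted = to.jit.script(net)

# Serialize the ScriptModule to disk and load it back
export_file = osp.join(tempfile.mkdtemp(), "policy.pt")
scripted.save(export_file)
loaded = to.jit.load(export_file)

obs = to.rand(4, dtype=to.double)
to.testing.assert_allclose(scripted(obs), loaded(obs))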
Example #7
def test_export_rcspysim(env: Env, policy: Policy, tmpdir: str):
    from rcsenv import ControlPolicy

    # Generate scripted version (double mode for CPP compatibility)
    scripted = policy.double().script()
    print(scripted.graph)

    # Export
    export_file = osp.join(tmpdir, "policy.pt")
    to.jit.save(scripted, export_file)

    # Import on the C++ side
    cpp = ControlPolicy("torch", export_file)

    # Compare a couple of inputs
    for _ in range(50):
        obs = policy.env_spec.obs_space.sample_uniform()
        obs = to.from_numpy(obs).to(dtype=to.double)
        act_script = scripted(obs).cpu().numpy()
        act_cpp = cpp(obs, policy.env_spec.act_space.flat_dim)
        assert act_cpp == pytest.approx(act_script)

    # Test after reset
    if hasattr(scripted, "reset"):
        scripted.reset()
        cpp.reset()
        obs = policy.env_spec.obs_space.sample_uniform()
        obs = to.from_numpy(obs).to(dtype=to.double)
        act_script = scripted(obs).cpu().numpy()
        act_cpp = cpp(obs, policy.env_spec.act_space.flat_dim)
        assert act_cpp == pytest.approx(act_script)
Example #8
def cpp_export(
    save_dir: pyrado.PathLike,
    policy: Policy,
    env: Optional[SimEnv] = None,
    policy_export_name: str = "policy_export",
    write_policy_node: bool = True,
    policy_node_name: str = "policy",
):
    """
    Convenience function to export the policy using PyTorch's scripting or tracing, and the experiment's XML
    configuration if the environment is from RcsPySim.

    :param save_dir: directory to save in
    :param policy: (trained) policy
    :param env: environment the policy was trained in
    :param policy_export_name: name of the exported policy file without the file type ending
    :param write_policy_node: if `True`, write the PyTorch-based control policy into the experiment's XML configuration.
                              This requires the experiment's XML configuration to be exported beforehand.
    :param policy_node_name: name of the control policies node in the XML file, e.g. 'policy' or 'preStrikePolicy'
    """
    if not osp.isdir(save_dir):
        raise pyrado.PathErr(given=save_dir)
    if not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if not isinstance(policy_export_name, str):
        raise pyrado.TypeErr(given=policy_export_name, expected_type=str)

    # Use torch.jit.trace / torch.jit.script (the latter if recurrent) to generate a torch.jit.ScriptModule
    ts_module = policy.double().script()  # can be evaluated like a regular PyTorch module

    # Serialize the script module to a file and save it in the same directory we loaded the policy from
    policy_export_file = osp.join(save_dir, f"{policy_export_name}.pt")
    ts_module.save(policy_export_file)  # former .zip, and before that .pth
    print_cbt(f"Exported the loaded policy to {policy_export_file}",
              "g",
              bright=True)

    # Export the experiment config for C++
    exp_export_file = osp.join(save_dir, "ex_config_export.xml")
    if env is not None and isinstance(inner_env(env), RcsSim):
        inner_env(env).save_config_xml(exp_export_file)
        print_cbt(f"Exported experiment configuration to {exp_export_file}",
                  "g",
                  bright=True)

    # Open the XML file again to add the policy node
    if write_policy_node and osp.isfile(exp_export_file):
        tree = et.parse(exp_export_file)
        root = tree.getroot()
        policy_node = et.Element(policy_node_name)
        policy_node.set("type", "torch")
        policy_node.set("file", f"{policy_export_name}.pt")
        root.append(policy_node)
        tree.write(exp_export_file)
        print_cbt(
            f"Added {policy_export_name}.pt to the experiment configuration.",
            "g")
Example #9
def plot_features(ro: StepSequence, policy: Policy):
    """
    Plot all features given the policy and the observation trajectories.

    :param policy: linear policy used during the rollout
    :param ro: input rollout
    """
    if not isinstance(policy, LinearPolicy):
        print_cbt('Plotting the feature values is only supported for linear policies!', 'y')
        return

    if hasattr(ro, 'observations'):
        # Use the recorded time stamps if available
        if hasattr(ro, 'env_infos'):
            t = ro.env_infos.get('t', np.arange(0, ro.length))
        else:
            t = np.arange(0, ro.length)

        # Recover the features from the observations
        feat_vals = policy.eval_feats(to.from_numpy(ro.observations))
        dim_feat = range(feat_vals.shape[1])
        if len(dim_feat) <= 6:
            divisor = 2
        elif len(dim_feat) <= 12:
            divisor = 4
        else:
            divisor = 8
        num_cols = int(np.ceil(len(dim_feat) / divisor))
        num_rows = int(np.ceil(len(dim_feat) / num_cols))

        fig, axs = plt.subplots(num_rows,
                                num_cols,
                                figsize=(num_cols * 5, num_rows * 3),
                                constrained_layout=True)
        fig.suptitle('Feature values over Time')
        plt.subplots_adjust(hspace=.5)
        colors = plt.get_cmap('tab20')(np.linspace(0, 1, len(dim_feat)))

        if len(dim_feat) == 1:
            axs.plot(t,
                     feat_vals[:-1, dim_feat[0]],
                     label=_get_obs_label(ro, dim_feat[0]))
            axs.legend()
        else:
            for i in range(num_rows):
                for j in range(num_cols):
                    if j + i * num_cols < len(dim_feat):
                        # Omit the last observation for simplicity
                        axs[i, j].plot(t,
                                       feat_vals[:-1, j + i * num_cols],
                                       label=rf'$\phi_{j + i*num_cols}$',
                                       c=colors[j + i * num_cols])
                        axs[i, j].legend()
                    else:
                        # We might create more subplots than there are observations
                        pass
        plt.show()
Example #10
def plot_features(ro: StepSequence, policy: Policy):
    """
    Plot all features given the policy and the observation trajectories.

    :param policy: linear policy used during the rollout
    :param ro: input rollout
    """
    if not isinstance(policy, LinearPolicy):
        print_cbt("Plotting the feature values is only supported for linear policies!", "r")
        return

    if hasattr(ro, "observations"):
        # Use recorded time stamps if possible
        t = getattr(ro, "time", np.arange(0, ro.length + 1))[:-1]

        # Recover the features from the observations
        feat_vals = policy.eval_feats(to.from_numpy(ro.observations))
        dim_feat = range(feat_vals.shape[1])
        if len(dim_feat) <= 6:
            divisor = 2
        elif len(dim_feat) <= 12:
            divisor = 4
        else:
            divisor = 8
        num_cols = int(np.ceil(len(dim_feat) / divisor))
        num_rows = int(np.ceil(len(dim_feat) / num_cols))

        fig, axs = plt.subplots(num_rows,
                                num_cols,
                                figsize=(num_cols * 5, num_rows * 3),
                                tight_layout=True)
        axs = np.atleast_2d(axs)
        axs = correct_atleast_2d(axs)
        fig.canvas.manager.set_window_title("Feature Values over Time")
        plt.subplots_adjust(hspace=0.5)
        colors = plt.get_cmap("tab20")(np.linspace(0, 1, len(dim_feat)))

        if len(dim_feat) == 1:
            axs[0, 0].plot(t,
                           feat_vals[:-1, dim_feat[0]],
                           label=_get_obs_label(ro, dim_feat[0]))
            axs[0, 0].legend()
        else:
            for i in range(num_rows):
                for j in range(num_cols):
                    if j + i * num_cols < len(dim_feat):
                        # Omit the last observation for simplicity
                        axs[i, j].plot(t,
                                       feat_vals[:-1, j + i * num_cols],
                                       c=colors[j + i * num_cols])
                        axs[i, j].set_ylabel(rf"$\phi_{{{j + i*num_cols}}}$")
                    else:
                        # We might create more subplots than there are observations
                        axs[i, j].remove()
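
The np.atleast_2d(axs) call above compensates for plt.subplots returning a bare Axes, a 1-D array, or a 2-D array depending on the grid shape; correct_atleast_2d is a pyrado helper that presumably also fixes the orientation for single-column grids. A quick standalone illustration of the shape issue:

import matplotlib.pyplot as plt
import numpy as np

for shape in [(1, 1), (1, 3), (2, 3)]:
    fig, axs = plt.subplots(*shape)
    axs = np.atleast_2d(axs)
    assert axs.ndim == 2  # uniform axs[i, j] indexing
    plt.close(fig)
# Caveat: for an (N, 1) grid, np.atleast_2d yields shape (1, N) instead of
# (N, 1), which is why an orientation fix-up is needed on top.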
Example #11
def test_script_nonrecurrent(env: Env, policy: Policy):
    # Generate scripted version
    scripted = policy.double().script()

    # Compare results
    sample = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(sample)
    act_reg = policy(obs)
    act_script = scripted(obs)
    to.testing.assert_allclose(act_reg, act_script)
Example #12
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    hid = policy.init_hidden()
    obs = env.obs_space.sample_uniform()
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())
    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act, out2, hid = policy(obs, hid)
        assert isinstance(out2, to.Tensor)
    else:
        act, hid = policy(obs, hid)
    assert isinstance(act, to.Tensor) and isinstance(hid, to.Tensor)
Example #13
def test_parallel_sampling_deterministic_smoke_test_w_min_steps(
        tmpdir_factory, env: SimEnv, policy: Policy, algo, min_rollouts: int,
        min_steps: int):
    env.max_steps = 20

    seeds = (0, 1)
    nums_workers = (1, 2, 4)

    logging_results = []
    rollout_results: List[List[List[List[StepSequence]]]] = []
    for seed in seeds:
        logging_results.append((seed, []))
        rollout_results.append([])
        for num_workers in nums_workers:
            pyrado.set_seed(seed)
            policy.init_param(None)
            ex_dir = str(
                tmpdir_factory.mktemp(
                    f"seed={seed}-num_workers={num_workers}"))
            set_log_prefix_dir(ex_dir)
            vfcn = FNN(input_size=env.obs_space.flat_dim,
                       output_size=1,
                       hidden_sizes=[16, 16],
                       hidden_nonlin=to.tanh)
            critic = GAE(vfcn,
                         gamma=0.98,
                         lamda=0.95,
                         batch_size=32,
                         lr=1e-3,
                         standardize_adv=False)
            alg = algo(
                ex_dir,
                env,
                policy,
                critic,
                max_iter=3,
                min_rollouts=min_rollouts,
                min_steps=min_steps * env.max_steps,
                num_workers=num_workers,
            )
            alg.sampler = RolloutSavingWrapper(alg.sampler)
            alg.train()
            with open(f"{ex_dir}/progress.csv") as f:
                logging_results[-1][1].append(str(f.read()))
            rollout_results[-1].append(alg.sampler.rollouts)

    # Test that the observations for all number of workers are equal.
    for rollouts in rollout_results:
        for ros_a, ros_b in [(a, b) for a in rollouts for b in rollouts]:
            assert len(ros_a) == len(ros_b)
            for ro_a, ro_b in zip(ros_a, ros_b):
                assert len(ro_a) == len(ro_b)
                for r_a, r_b in zip(ro_a, ro_b):
                    assert r_a.observations == pytest.approx(r_b.observations)

    # Test that different seeds actually produce different results.
    for results_a, results_b in [(a, b) for seed_a, a in logging_results
                                 for seed_b, b in logging_results
                                 if seed_a != seed_b]:
        for result_a, result_b in [(a, b) for a in results_a for b in results_b
                                   if a is not b]:
            assert result_a != result_b

    # Test that same seeds produce same results.
    for _, results in logging_results:
        for result_a, result_b in [(a, b) for a in results for b in results]:
            assert result_a == result_b
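
The seed assertions rely on global seeding making all sampling deterministic. A minimal self-contained illustration of that contract, using torch and numpy seeding directly rather than pyrado.set_seed:

import numpy as np
import torch as to

def seeded_draw(seed: int) -> np.ndarray:
    to.manual_seed(seed)
    np.random.seed(seed)
    return to.rand(5).numpy() + np.random.rand(5)

assert np.allclose(seeded_draw(0), seeded_draw(0))      # same seed, same result
assert not np.allclose(seeded_draw(0), seeded_draw(1))  # different seeds differ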
Example #14
def draw_policy_params(policy: Policy,
                       env_spec: EnvSpec,
                       cmap_name: str = 'RdBu',
                       ax_hm: plt.Axes = None,
                       annotate: bool = True,
                       annotation_valfmt: str = '{x:.2f}',
                       colorbar_label: str = '',
                       x_label: str = None,
                       y_label: str = None,
                       ) -> plt.Figure:
    """
    Plot the weights and biases as images, and a color bar.

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param policy: policy to visualize
    :param env_spec: environment specification
    :param cmap_name: name of the color map, e.g. 'inferno', 'RdBu', or 'viridis'
    :param ax_hm: axis to draw the heat map onto, if equal to None a new figure is opened
    :param annotate: select if the heat map should be annotated
    :param annotation_valfmt: format of the annotations inside the heat map, irrelevant if annotate = False
    :param colorbar_label: label for the color bar
    :param x_label: label for the x axis
    :param y_label: label for the y axis
    :return: handles to figures
    """
    if not isinstance(policy, nn.Module):
        raise pyrado.TypeErr(given=policy, expected_type=nn.Module)
    cmap = plt.get_cmap(cmap_name)

    # Create axes and subplots depending on the NN structure
    num_rows = len(list(policy.parameters()))
    fig = plt.figure(figsize=(14, 10), tight_layout=False)
    gs = fig.add_gridspec(num_rows, 2, width_ratios=[14, 1])  # right column is the color bar
    ax_cb = fig.add_subplot(gs[:, 1])

    # Accumulative norm for the colors
    norm = AccNorm()

    for i, (name, param) in enumerate(policy.named_parameters()):
        # Create current axis
        ax = plt.subplot(gs[i, 0])
        ax.set_title(name.replace('_', r'\_'))

        # Convert the data and plot the image with the colors proportional to the parameters
        if param.ndim == 3:
            # For example convolution layers
            param = param.flatten(0)
            print_cbt(f'Flattened the first dimension of the {name} parameter tensor.', 'y')
        data = np.atleast_2d(param.detach().cpu().numpy())

        img = plt.imshow(data, cmap=cmap, norm=norm, aspect='auto', origin='lower')

        if annotate:
            _annotate_img(
                img,
                thold_lo=0.75*min(policy.param_values).detach().cpu().numpy(),
                thold_up=0.75*max(policy.param_values).detach().cpu().numpy(),
                valfmt=annotation_valfmt
            )

        # Prepare the ticks
        if isinstance(policy, ADNPolicy):
            if name == 'obs_layer.weight':
                ax.set_xticks(np.arange(env_spec.obs_space.flat_dim))
                ax.set_yticks(np.arange(env_spec.act_space.flat_dim))
                ax.set_xticklabels(env_spec.obs_space.labels)
                ax.set_yticklabels(env_spec.act_space.labels)
            elif name in ['obs_layer.bias', 'nonlin_layer.log_weight', 'nonlin_layer.bias']:
                ax.set_xticks(np.arange(env_spec.act_space.flat_dim))
                ax.set_xticklabels(env_spec.act_space.labels)
                ax.yaxis.set_major_locator(ticker.NullLocator())
                ax.yaxis.set_minor_formatter(ticker.NullFormatter())
            elif name == 'prev_act_layer.weight':
                ax.set_xticks(np.arange(env_spec.act_space.flat_dim))
                ax.set_yticks(np.arange(env_spec.act_space.flat_dim))
                ax.set_xticklabels(env_spec.act_space.labels)
                ax.set_yticklabels(env_spec.act_space.labels)
            elif name in ['_log_tau', '_log_kappa', '_log_capacity']:
                ax.xaxis.set_major_locator(ticker.NullLocator())
                ax.yaxis.set_major_locator(ticker.NullLocator())
                ax.xaxis.set_minor_formatter(ticker.NullFormatter())
                ax.yaxis.set_minor_formatter(ticker.NullFormatter())
            else:
                ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        elif isinstance(policy, NFPolicy):
            if name == 'obs_layer.weight':
                ax.set_xticks(np.arange(env_spec.obs_space.flat_dim))
                ax.yaxis.set_major_locator(ticker.NullLocator())
                ax.set_xticklabels(env_spec.obs_space.labels)
                ax.yaxis.set_minor_formatter(ticker.NullFormatter())
            elif name in ['_log_tau', '_log_kappa', '_potentials_init', 'resting_level', 'obs_layer.bias',
                          'conv_layer.weight', 'nonlin_layer.log_weight', 'nonlin_layer.bias']:
                ax.xaxis.set_major_locator(ticker.NullLocator())
                ax.yaxis.set_major_locator(ticker.NullLocator())
                ax.xaxis.set_minor_formatter(ticker.NullFormatter())
                ax.yaxis.set_minor_formatter(ticker.NullFormatter())
            elif name == 'act_layer.weight':
                ax.xaxis.set_major_locator(ticker.NullLocator())
                ax.set_yticks(np.arange(env_spec.act_space.flat_dim))
                ax.xaxis.set_minor_formatter(ticker.NullFormatter())
                ax.set_yticklabels(env_spec.act_space.labels)
            else:
                ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Add the color bar (call this within the loop to make the AccNorm scan every image)
        colorbar.ColorbarBase(ax_cb, cmap=cmap, norm=norm, label=colorbar_label)

    # Increase the vertical white spaces between the subplots
    plt.subplots_adjust(hspace=.7, wspace=0.1)

    # Set the labels on the provided axis (skipped if no axis was passed)
    if ax_hm is not None:
        if x_label is not None:
            ax_hm.set_xlabel(x_label)
        if y_label is not None:
            ax_hm.set_ylabel(y_label)

    return fig
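
AccNorm accumulates the value range over all images so that every heat map shares one color scale. A standalone sketch of the shared-scale idea using a fixed matplotlib Normalize computed up front (AccNorm is pyrado's accumulative variant; this precomputes the range instead):

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

params = [np.random.randn(3, 4), np.random.randn(1, 4)]
vmin = min(p.min() for p in params)
vmax = max(p.max() for p in params)
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)  # one scale for all subplots

fig, axs = plt.subplots(len(params), 1)
for ax, p in zip(axs, params):
    ax.imshow(p, cmap="RdBu", norm=norm, aspect="auto", origin="lower")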
Example #15
def test_time_policy_one_step(env: Env, policy: Policy):
    policy.reset()
    obs = env.obs_space.sample_uniform()
    obs = to.from_numpy(obs)
    act = policy(obs)
    assert isinstance(act, to.Tensor)
Example #16
def test_parameterized_policies_init_param(env: Env, policy: Policy):
    some_values = to.ones_like(policy.param_values)
    policy.init_param(some_values)
    to.testing.assert_allclose(policy.param_values, some_values)