Ejemplo n.º 1
0
    def test_get_actions(self, batch_size, hidden_sizes):
        """Check the mean and action shape from TanhGaussianMLPPolicy.get_actions.

        The policy is built with every hidden and output weight initialised to
        one and no hidden nonlinearity. Fed an all-ones observation batch, it
        is expected to report a mean of (approximately) 1.0 in every action
        dimension. It must also return one action per observation.
        """
        env = MetaRLEnv(DummyBoxEnv())
        obs_dim = env.observation_space.flat_dim
        act_dim = env.action_space.flat_dim
        observations = torch.ones([batch_size, obs_dim], dtype=torch.float32)

        policy = TanhGaussianMLPPolicy(
            env_spec=env,
            hidden_sizes=hidden_sizes,
            init_std=2.,
            hidden_nonlinearity=None,
            std_parameterization='exp',
            hidden_w_init=nn.init.ones_,
            output_w_init=nn.init.ones_,
        )

        actions, agent_info = policy.get_actions(observations)

        # Reference mean: a (batch_size, act_dim) array filled with 1.0.
        target_mean = torch.full([batch_size, act_dim], 1.0).numpy()

        assert np.allclose(agent_info['mean'], target_mean, rtol=1e-3)
        assert actions.shape == (batch_size, act_dim)
Ejemplo n.º 2
0
    def test_is_pickleable(self, batch_size, hidden_sizes):
        """Round-trip the policy through pickle and compare its outputs.

        The unpickled policy must report a mean close to the original
        policy's for the same observations. Both policies must also produce
        actions of the same shape. Only shapes are compared for the actions
        themselves.
        """
        env = MetaRLEnv(DummyBoxEnv())
        observations = torch.ones(
            [batch_size, env.observation_space.flat_dim], dtype=torch.float32)

        original = TanhGaussianMLPPolicy(
            env_spec=env,
            hidden_sizes=hidden_sizes,
            init_std=2.,
            hidden_nonlinearity=None,
            std_parameterization='exp',
            hidden_w_init=nn.init.ones_,
            output_w_init=nn.init.ones_,
        )

        actions_before, info_before = original.get_actions(observations)

        restored = pickle.loads(pickle.dumps(original))
        actions_after, info_after = restored.get_actions(observations)

        # np.allclose is asymmetric (rtol scales |second arg|): the original
        # policy's mean is the reference.
        assert np.allclose(info_after['mean'], info_before['mean'], rtol=1e-3)
        assert actions_before.shape == actions_after.shape