def test_get_actions(self, batch_size, hidden_sizes):
    """Check that get_actions returns the expected mean and action shape.

    The network is built with all-ones weights, no hidden nonlinearity,
    and an all-ones observation, so the reported pre-tanh mean of every
    action dimension should be 1.0.
    """
    env_spec = MetaRLEnv(DummyBoxEnv())
    obs_dim = env_spec.observation_space.flat_dim
    act_dim = env_spec.action_space.flat_dim
    observations = torch.ones([batch_size, obs_dim], dtype=torch.float32)
    policy = TanhGaussianMLPPolicy(env_spec=env_spec,
                                   hidden_sizes=hidden_sizes,
                                   init_std=2.,
                                   hidden_nonlinearity=None,
                                   std_parameterization='exp',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)
    actions, agent_info = policy.get_actions(observations)
    target_mean = torch.full([batch_size, act_dim], 1.0)
    # The mean is deterministic given the fixed weights; sampled actions
    # are stochastic, so only their shape is asserted.
    assert np.allclose(agent_info['mean'], target_mean.numpy(), rtol=1e-3)
    assert actions.shape == (batch_size, act_dim)
def test_is_pickleable(self, batch_size, hidden_sizes):
    """Verify a pickle round trip leaves the policy's behavior unchanged.

    The deterministic distribution mean must be identical before and
    after pickling; sampled actions only need to agree in shape.
    """
    env_spec = MetaRLEnv(DummyBoxEnv())
    obs_dim = env_spec.observation_space.flat_dim
    observations = torch.ones([batch_size, obs_dim], dtype=torch.float32)
    policy = TanhGaussianMLPPolicy(env_spec=env_spec,
                                   hidden_sizes=hidden_sizes,
                                   init_std=2.,
                                   hidden_nonlinearity=None,
                                   std_parameterization='exp',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)
    actions_before, info_before = policy.get_actions(observations)
    # Round-trip the policy through pickle in one step.
    restored = pickle.loads(pickle.dumps(policy))
    actions_after, info_after = restored.get_actions(observations)
    assert np.allclose(info_after['mean'], info_before['mean'], rtol=1e-3)
    assert actions_before.shape == actions_after.shape