def test_get_action(self, mock_normal, obs_dim, action_dim, hidden_dim):
        """Check single and batched actions against a stubbed LSTM model.

        ``GaussianLSTMModel`` is replaced by ``SimpleGaussianLSTMModel`` while
        the policy is constructed and ``mock_normal`` is forced to return
        0.5, so the action and the reported ``mean`` / ``log_std`` are
        expected to be constant arrays of shape ``action_dim``.

        NOTE(review): another ``test_get_action`` is defined further down in
        this file; if both live in the same class the later definition
        shadows this one -- confirm.

        Args:
            mock_normal: Mock whose return value is set to 0.5 here;
                presumably the patched noise sampler injected by a
                decorator outside this block -- TODO confirm.
            obs_dim: Observation dimension handed to ``DummyBoxEnv``.
            action_dim: Action dimension handed to ``DummyBoxEnv``; also
                the shape of every expected array.
            hidden_dim: Not used by this test.

        """
        mock_normal.return_value = 0.5
        dummy_env = DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim)
        env = TfEnv(dummy_env)

        model_target = ('metarl.tf.policies.'
                        'gaussian_lstm_policy.GaussianLSTMModel')
        # The patch only needs to be active while the policy is constructed.
        with mock.patch(model_target, new=SimpleGaussianLSTMModel):
            policy = GaussianLSTMPolicy(env_spec=env.spec,
                                        state_include_action=False)

        # Presumably mean + noise * exp(log_std) with all three equal to 0.5.
        expected_action = np.full(action_dim, 0.5 * np.exp(0.5) + 0.5)
        # Broadcasting the already-filled array yields the same values.
        target_action = np.full(action_dim, expected_action)
        expected_mean = np.full(action_dim, 0.5)
        expected_log_std = np.full(action_dim, 0.5)

        def check_step(sampled, mean, log_std):
            """Assert one action and its distribution info match."""
            assert env.action_space.contains(sampled)
            assert np.allclose(sampled, target_action, atol=1e-6)
            assert np.array_equal(mean, expected_mean)
            assert np.array_equal(log_std, expected_log_std)

        policy.reset()
        obs = env.reset()

        # Single-observation path.
        action, agent_info = policy.get_action(obs)
        check_step(action, agent_info['mean'], agent_info['log_std'])

        # Batched path with a batch of one observation.
        actions, agent_infos = policy.get_actions([obs])
        batch = zip(actions, agent_infos['mean'], agent_infos['log_std'])
        for step in batch:
            check_step(*step)
    def test_get_action(self, obs_dim, action_dim, hidden_dim):
        """Check that sampled actions lie inside the env's action space.

        Builds a ``GaussianLSTMPolicy`` on a float32 placeholder of shape
        ``[None, None, flat_dim]`` and verifies both ``get_action`` and
        ``get_actions`` on the flattened reset observation.

        NOTE(review): an earlier ``test_get_action`` appears above in this
        file; if both live in the same class this definition shadows it --
        confirm.

        Args:
            obs_dim: Observation dimension handed to ``DummyBoxEnv``.
            action_dim: Action dimension handed to ``DummyBoxEnv``.
            hidden_dim: Hidden dimension handed to ``GaussianLSTMPolicy``.

        """
        dummy_env = DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim)
        env = MetaRLEnv(dummy_env)

        flat_dim = env.observation_space.flat_dim
        obs_ph = tf.compat.v1.placeholder(tf.float32,
                                          shape=[None, None, flat_dim],
                                          name='obs')

        policy = GaussianLSTMPolicy(env_spec=env.spec,
                                    hidden_dim=hidden_dim,
                                    state_include_action=False)
        policy.build(obs_ph)
        policy.reset()
        first_obs = env.reset()

        # Single-observation path.
        single_action, _ = policy.get_action(first_obs.flatten())
        assert env.action_space.contains(single_action)

        # Batched path with a batch of one observation.
        batch_actions, _ = policy.get_actions([first_obs.flatten()])
        for sampled in batch_actions:
            assert env.action_space.contains(sampled)