def test_dist_info_sym_include_action(self, obs_dim, action_dim, hidden_dim):
    """Check dist_info_sym when the previous action is part of the state.

    GaussianLSTMModel is patched with SimpleGaussianLSTMModel while the
    policy is constructed. The symbolic distribution is then evaluated on
    a batch of two single-step observations, and both 'mean' and
    'log_std' are asserted to equal arrays filled with 0.5.

    Args:
        obs_dim: Observation dimensions forwarded to DummyBoxEnv.
        action_dim (tuple): Action dimensions forwarded to DummyBoxEnv;
            also appended to (2, 1) to form the expected output shape.
        hidden_dim: Supplied by the test parametrization; not read here.

    """
    env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
    obs_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(None, None, env.observation_space.flat_dim))

    model_target = ('metarl.tf.policies.gaussian_lstm_policy.'
                    'GaussianLSTMModel')
    with mock.patch(model_target, new=SimpleGaussianLSTMModel):
        policy = GaussianLSTMPolicy(env_spec=env.spec,
                                    state_include_action=True)

    policy.reset()
    obs = env.reset()

    # Batch of two sequences, each one step long.
    batched_shape = (2, 1) + action_dim
    prev_action = np.zeros(batched_shape)
    dist_sym = policy.dist_info_sym(
        obs_var=obs_ph,
        state_info_vars={'prev_action': prev_action},
        name='p2_sym')

    flat_obs = obs.flatten()
    feed = {obs_ph: [[flat_obs], [flat_obs]]}
    dist = self.sess.run(dist_sym, feed_dict=feed)

    expected = np.full(batched_shape, 0.5)
    assert np.array_equal(dist['mean'], expected)
    assert np.array_equal(dist['log_std'], expected)
def test_get_action(self, mock_normal, obs_dim, action_dim, hidden_dim):
    """Check get_action and get_actions against fixed expected values.

    The policy is built with GaussianLSTMModel patched by
    SimpleGaussianLSTMModel, and mock_normal is made to return 0.5.
    Every action is expected to be close to 0.5 * exp(0.5) + 0.5, and
    the reported 'mean' and 'log_std' to equal arrays filled with 0.5.

    Args:
        mock_normal: Mock whose return_value is set to 0.5. Presumably
            the patched noise sampler injected by a decorator outside
            this block -- confirm against the decorator.
        obs_dim: Observation dimensions forwarded to DummyBoxEnv.
        action_dim: Action dimensions forwarded to DummyBoxEnv; also
            used as the shape of the expected arrays.
        hidden_dim: Supplied by the test parametrization; not read here.

    """
    mock_normal.return_value = 0.5
    env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))

    model_target = ('metarl.tf.policies.gaussian_lstm_policy.'
                    'GaussianLSTMModel')
    with mock.patch(model_target, new=SimpleGaussianLSTMModel):
        policy = GaussianLSTMPolicy(env_spec=env.spec,
                                    state_include_action=False)

    expected_action = np.full(action_dim, 0.5 * np.exp(0.5) + 0.5)
    expected_mean = np.full(action_dim, 0.5)
    expected_log_std = np.full(action_dim, 0.5)

    def check_sample(sampled, mean, log_std):
        """Apply the same four assertions to one sampled action."""
        assert env.action_space.contains(sampled)
        assert np.allclose(sampled,
                           np.full(action_dim, expected_action),
                           atol=1e-6)
        assert np.array_equal(mean, expected_mean)
        assert np.array_equal(log_std, expected_log_std)

    policy.reset()
    obs = env.reset()

    # Single-observation path.
    action, agent_info = policy.get_action(obs)
    check_sample(action, agent_info['mean'], agent_info['log_std'])

    # Batched path with a batch of one observation.
    actions, agent_infos = policy.get_actions([obs])
    per_sample = zip(actions, agent_infos['mean'], agent_infos['log_std'])
    for sampled, mean, log_std in per_sample:
        check_sample(sampled, mean, log_std)
def test_gaussian_lstm_policy(self):
    """Smoke-test get_action on a freshly initialized policy.

    Builds a GaussianLSTMPolicy with hidden_dim=1, runs the global
    variable initializer in self.sess, resets the policy and asserts
    that get_action returns a truthy value for the observation space's
    upper bound.
    """
    policy = GaussianLSTMPolicy(env_spec=self.env, hidden_dim=1)

    init_op = tf.compat.v1.global_variables_initializer()
    self.sess.run(init_op)

    policy.reset()
    upper_bound_obs = self.env.observation_space.high
    result = policy.get_action(upper_bound_obs)
    assert result
def test_get_action(self, obs_dim, action_dim, hidden_dim):
    """Check that sampled actions lie inside the action space.

    Builds the policy on an observation placeholder, then verifies that
    both get_action and get_actions return actions accepted by
    env.action_space.contains.

    Args:
        obs_dim: Observation dimensions forwarded to DummyBoxEnv.
        action_dim: Action dimensions forwarded to DummyBoxEnv.
        hidden_dim: Hidden dimension forwarded to GaussianLSTMPolicy.

    """
    env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))

    placeholder_shape = [None, None, env.observation_space.flat_dim]
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=placeholder_shape,
                                       name='obs')

    policy = GaussianLSTMPolicy(env_spec=env.spec,
                                hidden_dim=hidden_dim,
                                state_include_action=False)
    policy.build(obs_var)
    policy.reset()

    obs = env.reset()

    # Single-observation path.
    single_action, _ = policy.get_action(obs.flatten())
    assert env.action_space.contains(single_action)

    # Batched path with a batch of one observation.
    batch_actions, _ = policy.get_actions([obs.flatten()])
    for sampled in batch_actions:
        assert env.action_space.contains(sampled)
def test_dist_info_sym_wrong_input(self):
    """Check that mismatched batch sizes raise InvalidArgumentError.

    The symbolic distribution is built with a prev_action of batch size
    3, then the 'p2_sym' network input is evaluated while feeding only
    two observations. The run is expected to raise
    tf.errors.InvalidArgumentError.
    """
    env = TfEnv(DummyBoxEnv(obs_dim=(1, ), action_dim=(1, )))
    obs_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(None, None, env.observation_space.flat_dim))

    model_target = ('metarl.tf.policies.gaussian_lstm_policy.'
                    'GaussianLSTMModel')
    with mock.patch(model_target, new=SimpleGaussianLSTMModel):
        policy = GaussianLSTMPolicy(env_spec=env.spec,
                                    state_include_action=True)

    policy.reset()
    obs = env.reset()

    mismatched_prev_action = np.zeros((3, 1, 1))
    policy.dist_info_sym(
        obs_var=obs_ph,
        state_info_vars={'prev_action': mismatched_prev_action},
        name='p2_sym')

    # observation batch size = 2 but prev_action batch size = 3
    two_obs_feed = {obs_ph: [[obs.flatten()], [obs.flatten()]]}
    with pytest.raises(tf.errors.InvalidArgumentError):
        self.sess.run(policy.model.networks['p2_sym'].input,
                      feed_dict=two_obs_feed)