def test_is_pickleable(self):
    """Check that a pickled-and-restored policy yields the same logits.

    The policy is built on a fresh placeholder, one of its LSTM cell
    weights is overwritten with ones, and its logits are evaluated in
    ``self.sess``. The policy is then pickled, restored inside a new
    graph/session, rebuilt, and evaluated on the same input; both results
    must be equal.
    """
    env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
    flat_dim = env.observation_space.flat_dim
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=[None, None, flat_dim],
                                       name='obs')
    policy = CategoricalLSTMPolicy2(env_spec=env.spec,
                                    state_include_action=False)
    policy.build(obs_var)
    policy.reset()

    # The same flattened observation is fed twice: two rows holding one
    # observation each.
    flat_obs = env.reset().flatten()
    batch = [[flat_obs], [flat_obs]]

    # Overwrite the first weight variable of the LSTM cell with ones so the
    # comparison below runs against known, non-default parameter values.
    first_weight = policy.model._lstm_cell.weights[0]
    first_weight.load(tf.ones_like(first_weight).eval())

    (logits_before, ) = self.sess.run(
        [policy.distribution.logits],
        feed_dict={policy.model.input: batch})

    serialized = pickle.dumps(policy)

    # Restore into a brand-new graph so nothing is shared with the original
    # policy's graph.
    with tf.compat.v1.Session(graph=tf.Graph()) as restored_sess:
        restored_policy = pickle.loads(serialized)
        restored_obs_var = tf.compat.v1.placeholder(
            tf.float32, shape=[None, None, flat_dim], name='obs')
        restored_policy.build(restored_obs_var)
        (logits_after, ) = restored_sess.run(
            [restored_policy.distribution.logits],
            feed_dict={restored_policy.model.input: batch})

    assert np.array_equal(logits_before, logits_after)
def test_get_action(self, obs_dim, action_dim, hidden_dim):
    """Check that sampled actions belong to the environment's action space.

    Args:
        obs_dim: Observation dimensions passed to ``DummyDiscreteEnv``.
        action_dim: Action dimension passed to ``DummyDiscreteEnv``.
        hidden_dim: Hidden dimension passed to the policy.

    NOTE(review): the arguments are presumably supplied by a
    parametrization that is not visible in this chunk -- confirm.
    """
    env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    flat_dim = env.observation_space.flat_dim
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=[None, None, flat_dim],
                                       name='obs')
    policy = CategoricalLSTMPolicy2(env_spec=env.spec,
                                    hidden_dim=hidden_dim,
                                    state_include_action=False)
    policy.build(obs_var)
    policy.reset()

    flat_obs = env.reset().flatten()
    action_space = env.action_space

    # Single-observation API.
    single_action, _ = policy.get_action(flat_obs)
    assert action_space.contains(single_action)

    # Batched API, called with a one-element batch.
    batch_actions, _ = policy.get_actions([flat_obs])
    assert all(action_space.contains(act) for act in batch_actions)
def test_invalid_env(self):
    """Check that constructing the policy from a ``DummyBoxEnv`` spec
    raises ``ValueError``."""
    box_env = TfEnv(DummyBoxEnv())
    with pytest.raises(ValueError):
        CategoricalLSTMPolicy2(env_spec=box_env.spec)
def test_clone(self):
    """Check that a cloned policy keeps the original's ``env_spec``."""
    env = TfEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
    original = CategoricalLSTMPolicy2(env_spec=env.spec)
    duplicate = original.clone('CategoricalLSTMPolicyClone')
    assert original.env_spec == duplicate.env_spec
def test_state_info_specs_with_state_include_action(self):
    """Check ``state_info_specs`` when ``state_include_action`` is True.

    The policy is expected to report a single ``prev_action`` entry with
    shape ``(4, )``, matching the ``action_dim`` of 4 used here.
    """
    env = TfEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
    policy = CategoricalLSTMPolicy2(env_spec=env.spec,
                                    state_include_action=True)
    expected_specs = [('prev_action', (4, ))]
    assert policy.state_info_specs == expected_specs
def test_state_info_specs(self):
    """Check that ``state_info_specs`` is empty when
    ``state_include_action`` is False."""
    env = TfEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
    policy = CategoricalLSTMPolicy2(env_spec=env.spec,
                                    state_include_action=False)
    specs = policy.state_info_specs
    assert specs == []