def test_is_pickleable(self, obs_dim, action_dim):
    """Check that a pickle round trip preserves the policy's output.

    The policy is built, one hidden-layer bias is overwritten with ones,
    and the action probabilities for a single observation are recorded.
    The policy is then pickled, restored inside a fresh graph and session,
    rebuilt, and queried with the same observation; both results must be
    exactly equal.
    """
    env = MetaRLEnv(
        DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
    obs_shape = [None, None, env.observation_space.flat_dim]

    original_input = tf.compat.v1.placeholder(tf.float32,
                                              shape=obs_shape,
                                              name='obs')
    policy = CategoricalMLPPolicy(env_spec=env.spec)
    policy.build(original_input)
    obs = env.reset()

    # Look up the already-created bias variable of the first hidden layer.
    with tf.compat.v1.variable_scope(
            'CategoricalMLPPolicy/CategoricalMLPModel', reuse=True):
        hidden_bias = tf.compat.v1.get_variable('mlp/hidden_0/bias')

    # Set the bias to all ones - presumably so the restored parameters are
    # distinguishable from freshly initialised ones (confirm initialiser).
    all_ones = tf.ones_like(hidden_bias).eval()
    hidden_bias.load(all_ones)

    # Batch of one trajectory holding one time step.
    feed_obs = [[obs.flatten()]]
    expected_probs = self.sess.run([policy.distribution.probs],
                                   feed_dict={policy.model.input: feed_obs})

    serialized = pickle.dumps(policy)

    with tf.compat.v1.Session(graph=tf.Graph()) as fresh_sess:
        restored = pickle.loads(serialized)
        restored_input = tf.compat.v1.placeholder(tf.float32,
                                                  shape=obs_shape,
                                                  name='obs')
        restored.build(restored_input)
        restored_probs = fresh_sess.run(
            [restored.distribution.probs],
            feed_dict={restored.model.input: feed_obs})
        assert np.array_equal(expected_probs, restored_probs)
def test_categorial_mlp_policy(self):
    """Smoke test: a built policy returns a truthy result from get_action.

    Uses the fixture-provided ``self.env``, ``self.sess`` and
    ``self.obs_var``; the observation fed in is the upper bound of the
    observation space.
    """
    policy = CategoricalMLPPolicy(env_spec=self.env, hidden_sizes=(1, ))

    # NOTE(review): the initializer runs before build(); this presumably
    # relies on the variables already existing at this point - confirm.
    init_op = tf.compat.v1.global_variables_initializer()
    self.sess.run(init_op)
    policy.build(self.obs_var)

    upper_bound_obs = self.env.observation_space.high
    result = policy.get_action(upper_bound_obs)
    assert result
def test_get_regularizable_vars(self, obs_dim, action_dim):
    """Check which variables the policy reports as regularizable.

    Exactly two variables are expected, and none of their names may
    contain 'bias' or 'output'.
    """
    wrapped = DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim)
    env = MetaRLEnv(wrapped)

    obs_shape = [None, None, env.observation_space.flat_dim]
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=obs_shape,
                                       name='obs')

    policy = CategoricalMLPPolicy(env_spec=env.spec)
    policy.build(obs_var)

    regularizable = policy.get_regularizable_vars()
    assert len(regularizable) == 2
    for variable in regularizable:
        var_name = variable.name
        assert not ('bias' in var_name or 'output' in var_name)
def test_get_action(self, obs_dim, action_dim):
    """Check that sampled actions lie inside the environment action space.

    Covers both the single-observation ``get_action`` call and the batched
    ``get_actions`` call (three copies of the same flattened observation).
    """
    wrapped = DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim)
    env = MetaRLEnv(wrapped)

    obs_shape = [None, None, env.observation_space.flat_dim]
    obs_var = tf.compat.v1.placeholder(tf.float32,
                                       shape=obs_shape,
                                       name='obs')

    policy = CategoricalMLPPolicy(env_spec=env.spec)
    policy.build(obs_var)

    obs = env.reset()

    # Single observation.
    single_action, _ = policy.get_action(obs.flatten())
    assert env.action_space.contains(single_action)

    # Batch of three identical observations.
    obs_batch = [obs.flatten() for _ in range(3)]
    batch_actions, _ = policy.get_actions(obs_batch)
    for sampled in batch_actions:
        assert env.action_space.contains(sampled)