    def test_is_pickleable(self, obs_dim, action_dim):
        env = MetaRLEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
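        # Input placeholder shaped (batch, time-step, flattened observation).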
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalMLPPolicy(env_spec=env.spec)

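        # Building the policy creates its variables and wires in obs_var.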
        policy.build(obs_var)
        obs = env.reset()

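        # Look up the first hidden-layer bias inside the model's scope.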
        with tf.compat.v1.variable_scope(
                'CategoricalMLPPolicy/CategoricalMLPModel', reuse=True):
            bias = tf.compat.v1.get_variable('mlp/hidden_0/bias')
        # Overwrite the bias with all ones so the policy carries
        # non-default parameter values.
        bias.load(tf.ones_like(bias).eval())
        output1 = self.sess.run(
            [policy.distribution.probs],
            feed_dict={policy.model.input: [[obs.flatten()]]})

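        # Pickle the policy; its parameters are saved with it.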
        p = pickle.dumps(policy)

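        # Unpickle into a fresh graph; the rebuilt policy should produce the
        # same output for the same observation.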
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, env.observation_space.flat_dim],
                name='obs')
            policy_pickled.build(obs_var)
            output2 = sess.run(
                [policy_pickled.distribution.probs],
                feed_dict={policy_pickled.model.input: [[obs.flatten()]]})
            assert np.array_equal(output1, output2)

    def test_categorical_mlp_policy(self):
        categorical_mlp_policy = CategoricalMLPPolicy(env_spec=self.env,
                                                      hidden_sizes=(1, ))
        # Variables are created when the policy is built, so build first and
        # initialize afterwards.
        categorical_mlp_policy.build(self.obs_var)
        self.sess.run(tf.compat.v1.global_variables_initializer())

        obs = self.env.observation_space.high
        assert categorical_mlp_policy.get_action(obs)

    @pytest.mark.parametrize('obs_dim, action_dim', [((1, ), 1), ((2, ), 2),
                                                     ((1, 1), 1),
                                                     ((2, 2), 2)])
    def test_get_regularizable_vars(self, obs_dim, action_dim):
        env = MetaRLEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalMLPPolicy(env_spec=env.spec)
        policy.build(obs_var)
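        # Only hidden-layer weight matrices should be regularizable; biases
        # and the output layer are excluded.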
        reg_vars = policy.get_regularizable_vars()
        assert len(reg_vars) == 2
        for var in reg_vars:
            assert ('bias' not in var.name) and ('output' not in var.name)

    @pytest.mark.parametrize('obs_dim, action_dim', [((1, ), 1), ((2, ), 2),
                                                     ((1, 1), 1),
                                                     ((2, 2), 2)])
    def test_get_action(self, obs_dim, action_dim):
        env = MetaRLEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalMLPPolicy(env_spec=env.spec)

        policy.build(obs_var)
        obs = env.reset()

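        # A single flattened observation should map to a valid action.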
        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)

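        # get_actions accepts a batch of observations.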
        actions, _ = policy.get_actions(
            [obs.flatten(), obs.flatten(),
             obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)