def test_get_action_state_include_action(self, obs_dim, action_dim,
                                             hidden_dim):
        """Check that sampled actions fall inside the env's action space.

        A GaussianGRUPolicy is created with ``state_include_action=True``,
        built on a placeholder, and queried through both ``get_action`` and
        ``get_actions``; every returned action must be contained in
        ``env.action_space``.

        Args:
            obs_dim: Observation dimensions forwarded to DummyBoxEnv.
            action_dim: Action dimensions forwarded to DummyBoxEnv.
            hidden_dim: Hidden dimension forwarded to GaussianGRUPolicy.

        """
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))

        # Last axis of the input: flattened observation size plus the
        # product of the action dimensions.
        input_width = env.observation_space.flat_dim + np.prod(action_dim)
        obs_ph = tf.compat.v1.placeholder(tf.float32,
                                          shape=[None, None, input_width],
                                          name='obs')

        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_dim=hidden_dim,
                                   state_include_action=True)
        policy.build(obs_ph)

        # Single-observation path.
        policy.reset()
        first_obs = env.reset()
        single_action, _ = policy.get_action(first_obs.flatten())
        assert env.action_space.contains(single_action)

        # Batched path, after resetting the policy again.
        policy.reset()
        batch_actions, _ = policy.get_actions([first_obs.flatten()])
        assert all(
            env.action_space.contains(sampled) for sampled in batch_actions)
    def test_gaussian_gru_policy(self):
        """Smoke-test ``get_action`` on the upper-bound observation.

        Creates a GaussianGRUPolicy from the fixture environment with
        ``hidden_dim=1`` and ``state_include_action=False``, builds it on the
        fixture placeholder, and asserts that ``get_action`` returns a truthy
        value for ``observation_space.high``.
        """
        policy = GaussianGRUPolicy(env_spec=self.env,
                                   hidden_dim=1,
                                   state_include_action=False)

        # NOTE(review): the initializer runs before build(); whether build()
        # initializes its own variables is not visible here -- confirm.
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)

        policy.build(self.obs_var)
        policy.reset()

        upper_obs = self.env.observation_space.high
        # NOTE(review): this only checks truthiness of the returned value,
        # not the contents of the action.
        outcome = policy.get_action(upper_obs)
        assert outcome
    def test_is_pickleable(self):
        """Check that a pickle round-trip preserves the policy's outputs.

        The policy's ``log_std`` parameter is overwritten with ones, the
        distribution's ``loc`` and ``stddev()`` are evaluated, and the policy
        is pickled. The unpickled copy is rebuilt in a session on a fresh
        graph and must produce equal outputs for the same input.
        """
        env = MetaRLEnv(DummyBoxEnv(obs_dim=(1, ), action_dim=(1, )))
        input_shape = [None, None, env.observation_space.flat_dim]
        obs_ph = tf.compat.v1.placeholder(tf.float32,
                                          shape=input_shape,
                                          name='obs')
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   state_include_action=False)
        policy.build(obs_ph)

        # NOTE(review): reset is invoked twice and only the second result is
        # kept; the first call looks redundant -- confirm before removing.
        env.reset()
        obs = env.reset()

        model_scope = 'GaussianGRUPolicy/GaussianGRUModel'
        with tf.compat.v1.variable_scope(model_scope, reuse=True):
            log_std = tf.compat.v1.get_variable(
                'dist_params/log_std_param/parameter')
        # Overwrite the parameter with ones (presumably so the round-trip is
        # checked against a non-default value -- confirm).
        log_std.load(tf.ones_like(log_std).eval())

        # The same two-row input is fed before and after pickling.
        feed_batch = [[obs.flatten()], [obs.flatten()]]

        expected = self.sess.run(
            [policy.distribution.loc,
             policy.distribution.stddev()],
            feed_dict={policy.model.input: feed_batch})

        serialized = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as fresh_sess:
            restored = pickle.loads(serialized)
            restored_ph = tf.compat.v1.placeholder(tf.float32,
                                                   shape=input_shape,
                                                   name='obs')
            restored.build(restored_ph)
            # yapf: disable
            restored_outputs = [
                restored.distribution.loc,
                restored.distribution.stddev()
            ]
            actual = fresh_sess.run(
                restored_outputs,
                feed_dict={restored.model.input: feed_batch})
            assert np.array_equal(expected, actual)