    def test_obs_unflattened(self, action_dim, kernel_sizes, hidden_channels,
                             strides, paddings):
        """Test if a flattened image obs is passed to get_action
           then it is unflattened.
        """
        batch_size = 64
        input_width = 32
        input_height = 32
        in_channel = 3
        input_shape = (batch_size, in_channel, input_height, input_width)
        env = GymEnv(
            DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim))
        env = self._initialize_obs_env(env)

        env.reset()
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   hidden_channels=hidden_channels,
                                   hidden_sizes=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

        obs = env.observation_space.sample()
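        # pass the flattened observation; get_action should unflatten it
        # back to the image shape before running the network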
        action, _ = policy.get_action(env.observation_space.flatten(obs))
        env.step(action)
def test_visualization():
    inner_env = gym.make('MountainCar-v0')
    env = GymEnv(inner_env)

    env.reset()
    env.visualize()
    assert inner_env.metadata['render.modes'] == env.render_modes
    env.step(env.action_space.sample())
def test_inconsistent_env_infos():
    env = GymEnv('MountainCar-v0')
    env.reset()
    env._env_info = {'k1': 'v1', 'k2': 'v2'}
    with pytest.raises(RuntimeError,
                       match='GymEnv outputs inconsistent env_info keys.'):
        env.step(env.action_space.sample())
    # check that the order of keys doesn't matter for equality
    assert env._env_info.keys() == {'k2': 'v2', 'k1': 'v1'}.keys()
class TestQfDerivedPolicy(TfGraphTestCase):

    def setup_method(self):
        super().setup_method()
        self.env = GymEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec,
                                             qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_argmax_policy(self):
        obs = self.env.step(1).observation
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)

    def test_get_param(self):
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        assert self.policy.get_param_values() == return_var.eval()

    def test_is_pickleable(self):
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())
        obs = self.env.step(1).observation
        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
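        # unpickle in a fresh graph and verify the restored policy
        # returns the same action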
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2

    def test_does_not_support_dict_obs_space(self):
        """Test that policy raises error if passed a dict obs space."""
        env = GymEnv(DummyDictEnv(act_space_type='discrete'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec,
                                 name='does_not_support_dict_obs_space')
            DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

    def test_invalid_action_spaces(self):
        """Test that policy raises error if passed a dict obs space."""
        env = GymEnv(DummyDictEnv(act_space_type='box'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec)
            DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    def test_get_action(self, obs_dim, task_num, latent_dim, action_dim):
        env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        embedding_spec = InOutSpec(
            input_space=akro.Box(low=np.zeros(task_num),
                                 high=np.ones(task_num)),
            output_space=akro.Box(low=np.zeros(latent_dim),
                                  high=np.ones(latent_dim)))
        encoder = GaussianMLPEncoder(embedding_spec)
        policy = GaussianMLPTaskEmbeddingPolicy(env_spec=env.spec,
                                                encoder=encoder)

        env.reset()
        obs = env.step(1).observation
        latent = np.random.random((latent_dim, ))
        task = np.zeros(task_num)
        task[0] = 1

        action1, _ = policy.get_action_given_latent(obs, latent)
        action2, _ = policy.get_action_given_task(obs, task)
        action3, _ = policy.get_action(np.concatenate([obs.flatten(), task]))

        assert env.action_space.contains(action1)
        assert env.action_space.contains(action2)
        assert env.action_space.contains(action3)

        obses, latents, tasks = [obs] * 3, [latent] * 3, [task] * 3
        aug_obses = [np.concatenate([obs.flatten(), task])] * 3
        action1n, _ = policy.get_actions_given_latents(obses, latents)
        action2n, _ = policy.get_actions_given_tasks(obses, tasks)
        action3n, _ = policy.get_actions(aug_obses)

        for action in chain(action1n, action2n, action3n):
            assert env.action_space.contains(action)
    def test_is_pickleable(self):
        env = GymEnv(DummyDiscretePixelEnv())
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      filters=((3, (32, 32)), ),
                                      strides=(1, ),
                                      padding='SAME',
                                      hidden_sizes=(4, ))

        env.reset()
        obs = env.step(1).observation

        with tf.compat.v1.variable_scope('CategoricalCNNPolicy', reuse=True):
            cnn_bias = tf.compat.v1.get_variable('CNNModel/cnn/h0/bias')
            bias = tf.compat.v1.get_variable('MLPModel/mlp/hidden_0/bias')

        cnn_bias.load(tf.ones_like(cnn_bias).eval())
        bias.load(tf.ones_like(bias).eval())

        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, None) +
                                               policy.input_dim)
        dist_sym = policy.build(state_input, name='dist_sym').dist
        output1 = self.sess.run(dist_sym.probs,
                                feed_dict={state_input: [[obs]]})
        p = pickle.dumps(policy)

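        # rebuild the pickled policy in a fresh graph and compare the
        # distribution outputs with the original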
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            state_input = tf.compat.v1.placeholder(tf.float32,
                                                   shape=(None, None) +
                                                   policy.input_dim)
            dist_sym = policy_pickled.build(state_input, name='dist_sym').dist
            output2 = sess.run(dist_sym.probs,
                               feed_dict={state_input: [[obs]]})
            assert np.array_equal(output1, output2)
    def test_is_pickleable(self, obs_dim, action_dim):
        """Test if ContinuousMLPPolicy is pickleable"""
        env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        policy = ContinuousMLPPolicy(env_spec=env.spec)
        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, np.prod(obs_dim)))
        outputs = policy.build(state_input, name='policy')
        env.reset()
        obs = env.step(1).observation

        with tf.compat.v1.variable_scope('ContinuousMLPPolicy', reuse=True):
            bias = tf.compat.v1.get_variable('mlp/hidden_0/bias')
        # assign it to all one
        bias.load(tf.ones_like(bias).eval())
        output1 = self.sess.run([outputs],
                                feed_dict={state_input: [obs.flatten()]})

        p = pickle.dumps(policy)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            state_input = tf.compat.v1.placeholder(tf.float32,
                                                   shape=(None,
                                                          np.prod(obs_dim)))
            outputs = policy_pickled.build(state_input, name='policy')
            output2 = sess.run([outputs],
                               feed_dict={state_input: [obs.flatten()]})
            assert np.array_equal(output1, output2)
    def test_is_pickleable(self, kernel_sizes, hidden_channels, strides,
                           paddings):
        """Test if policy is pickable."""
        env = GymEnv(DummyDiscretePixelEnv())
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   image_format='NHWC',
                                   hidden_channels=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)
        env.reset()
        obs = env.step(1).observation

        output_action_1, _ = policy.get_action(obs.flatten())

        p = cloudpickle.dumps(policy)
        policy_pickled = cloudpickle.loads(p)
        output_action_2, _ = policy_pickled.get_action(obs)

        assert env.action_space.contains(int(output_action_1[0]))
        assert env.action_space.contains(int(output_action_2[0]))
        assert output_action_1.shape == output_action_2.shape
    def test_is_pickleable(self, action_dim, kernel_sizes, hidden_channels,
                           strides, paddings):
        """Test if policy is pickable."""
        batch_size = 64
        input_width = 32
        input_height = 32
        in_channel = 3
        input_shape = (batch_size, in_channel, input_height, input_width)
        env = GymEnv(
            DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim))

        env = self._initialize_obs_env(env)
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   hidden_channels=hidden_channels,
                                   hidden_sizes=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)
        env.reset()
        obs = env.step(1).observation

        output_action_1, _ = policy.get_action(obs.flatten())

        p = cloudpickle.dumps(policy)
        policy_pickled = cloudpickle.loads(p)
        output_action_2, _ = policy_pickled.get_action(obs)

        assert env.action_space.contains(int(output_action_1[0]))
        assert env.action_space.contains(int(output_action_2[0]))
        assert output_action_1.shape == output_action_2.shape
    def test_obs_unflattened(self, hidden_channels, kernel_sizes, strides,
                             hidden_sizes):
        """Test that a flattened image observation passed to get_action is
        unflattened.
        """
        env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        env.reset()
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      image_format='NHWC',
                                      kernel_sizes=kernel_sizes,
                                      hidden_channels=hidden_channels,
                                      strides=strides,
                                      hidden_sizes=hidden_sizes)
        obs = env.observation_space.sample()
        action, _ = policy.get_action(env.observation_space.flatten(obs))
        env.step(action)
    def test_is_pickleable(self, obs_dim, embedding_dim):
        env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=embedding_dim))
        embedding_spec = InOutSpec(input_space=env.spec.observation_space,
                                   output_space=env.spec.action_space)
        embedding = GaussianMLPEncoder(embedding_spec)

        env.reset()
        obs = env.step(env.action_space.sample()).observation
        obs_dim = env.spec.observation_space.flat_dim

        with tf.compat.v1.variable_scope('GaussianMLPEncoder/GaussianMLPModel',
                                         reuse=True):
            bias = tf.compat.v1.get_variable(
                'dist_params/mean_network/hidden_0/bias')
        # assign it to all one
        bias.load(tf.ones_like(bias).eval())
        output1 = self.sess.run(
            [embedding.distribution.loc,
             embedding.distribution.stddev()],
            feed_dict={embedding.model.input: [[obs.flatten()]]})

        p = pickle.dumps(embedding)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            embedding_pickled = pickle.loads(p)

            output2 = sess.run(
                [
                    embedding_pickled.distribution.loc,
                    embedding_pickled.distribution.stddev()
                ],
                feed_dict={embedding_pickled.model.input: [[obs.flatten()]]})
            assert np.array_equal(output1, output2)
    def test_get_actions(self, action_dim, kernel_sizes, hidden_channels,
                         strides, paddings):
        """Test get_actions function."""
        batch_size = 64
        input_width = 32
        input_height = 32
        in_channel = 3
        input_shape = (batch_size, in_channel, input_height, input_width)
        env = GymEnv(
            DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim))

        env = self._initialize_obs_env(env)
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   hidden_channels=hidden_channels,
                                   hidden_sizes=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

        env.reset()
        obs = env.step(1).observation

        actions, _ = policy.get_actions([obs, obs, obs])
        for action in actions:
            assert env.action_space.contains(int(action[0]))
            assert env.action_space.n == action_dim
    def test_get_qval_max_pooling(self, filters, strides, pool_strides,
                                  pool_shapes):
        env = GymEnv(DummyDiscretePixelEnv())
        obs = env.reset()[0]

        with mock.patch(('garage.tf.models.'
                         'cnn_mlp_merge_model.CNNModelWithMaxPooling'),
                        new=SimpleCNNModelWithMaxPooling):
            with mock.patch(('garage.tf.models.'
                             'cnn_mlp_merge_model.MLPMergeModel'),
                            new=SimpleMLPMergeModel):
                qf = ContinuousCNNQFunction(env_spec=env.spec,
                                            filters=filters,
                                            strides=strides,
                                            max_pooling=True,
                                            pool_strides=pool_strides,
                                            pool_shapes=pool_shapes)

        action_dim = env.action_space.shape

        obs = env.step(1).observation

        act = np.full(action_dim, 0.5)
        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])

        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])

        for output in outputs:
            assert np.array_equal(output, expected_output)
    def test_output_shape(self, obs_dim, action_dim):
        env = GymEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        qf = DiscreteMLPDuelingQFunction(env_spec=env.spec)
        env.reset()
        obs = env.step(1).observation

        outputs = self.sess.run(qf.q_vals, feed_dict={qf.input: [obs]})
        assert outputs.shape == (1, action_dim)
def test_time_limit_env():
    garage_env = GymEnv('Pendulum-v0', max_episode_length=200)
    garage_env._env._max_episode_steps = 200
    garage_env.reset()
    for _ in range(200):
        es = garage_env.step(garage_env.spec.action_space.sample())
    assert es.timeout and es.env_info['TimeLimit.truncated']
    assert es.env_info['GymEnv.TimeLimitTerminated']
def step_bullet_kuka_env(n_steps=1000):
    """Load, step, and visualize a Bullet Kuka environment.

    Args:
        n_steps (int): number of steps to run.

    """
    # Construct the environment
    env = GymEnv(gym.make('KukaBulletEnv-v0', renders=True, isDiscrete=True))

    # Reset the environment and launch the viewer
    env.reset()
    env.visualize()

    step_count = 0
    es = env.step(env.action_space.sample())
    while not es.last and step_count < n_steps:
        es = env.step(env.action_space.sample())
        step_count += 1
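
# Usage sketch for the helper above: a short rollout driver. It assumes
# pybullet's Gym bindings are installed (e.g. `import pybullet_envs`) so
# that 'KukaBulletEnv-v0' is registered before gym.make is called.
if __name__ == '__main__':
    step_bullet_kuka_env(n_steps=100)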
    def test_output_shape(self, obs_dim, action_dim):
        env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs = env.step(1).observation
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        outputs = qf.get_qval([obs], [act])
        assert outputs.shape == (1, 1)
def test_done_resets_step_cnt():
    env = GymEnv('MountainCar-v0')
    max_episode_length = env.spec.max_episode_length

    env.reset()
    for _ in range(max_episode_length):
        es = env.step(env.action_space.sample())
        if es.last:
            break
    assert env._step_cnt is None
    def test_obs_unflattened(self, kernel_sizes, hidden_channels, strides,
                             paddings):
        """Test if a flattened image obs is passed to get_action
           then it is unflattened.
        """
        env = GymEnv(DummyDiscretePixelEnv())
        env.reset()
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   image_format='NHWC',
                                   hidden_channels=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

        obs = env.observation_space.sample()
        action, _ = policy.get_action(env.observation_space.flatten(obs))
        env.step(action[0])
    def test_unflattened_input(self):
        env = GymEnv(DummyBoxEnv(obs_dim=(2, 2)))
        cmb = ContinuousMLPBaseline(env_spec=env.spec)
        env.reset()
        es = env.step(1)
        obs, rewards = es.observation, es.reward
        train_paths = [{'observations': [obs], 'returns': [rewards]}]
        cmb.fit(train_paths)
        paths = {'observations': [obs]}
        prediction = cmb.predict(paths)
        assert np.allclose(0., prediction)
class TestCategoricalCNNPolicyImageObs(TfGraphTestCase):
    def setup_method(self):
        super().setup_method()
        self.env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    @pytest.mark.parametrize('filters, strides, padding, hidden_sizes', [
        (((3, (32, 32)), ), (1, ), 'VALID', (4, )),
    ])
    def test_obs_unflattened(self, filters, strides, padding, hidden_sizes):
        self.policy = CategoricalCNNPolicy(env_spec=self.env.spec,
                                           filters=filters,
                                           strides=strides,
                                           padding=padding,
                                           hidden_sizes=hidden_sizes)
        obs = self.env.observation_space.sample()
        action, _ = self.policy.get_action(
            self.env.observation_space.flatten(obs))
        self.env.step(action)
class TestQfDerivedPolicyImageObs(TfGraphTestCase):
    def setup_method(self):
        super().setup_method()
        self.env = GymEnv(AtariEnv(DummyDiscretePixelEnvBaselines()),
                          is_image=True)
        self.qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                       filters=((1, (1, 1)), ),
                                       strides=(1, ),
                                       dueling=False)
        self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec,
                                             qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_obs_unflattened(self):
        """Test if a flattened image obs is passed to get_action
           then it is unflattened.
        """
        obs = self.env.observation_space.sample()
        action, _ = self.policy.get_action(
            self.env.observation_space.flatten(obs))
        self.env.step(action)
    def test_get_action(self, obs_dim, action_dim):
        env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        policy = GaussianMLPPolicy(env_spec=env.spec)

        env.reset()
        obs = env.step(1).observation

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)
        actions, _ = policy.get_actions(
            [obs.flatten(), obs.flatten(),
             obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
    def test_get_action(self, hidden_channels, kernel_sizes, strides,
                        hidden_sizes):
        """Test get_action function."""
        env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      image_format='NHWC',
                                      kernel_sizes=kernel_sizes,
                                      hidden_channels=hidden_channels,
                                      strides=strides,
                                      hidden_sizes=hidden_sizes)
        env.reset()
        obs = env.step(1).observation
        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)
    def test_build(self, obs_dim, action_dim):
        env = GymEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        qf = DiscreteMLPDuelingQFunction(env_spec=env.spec)
        env.reset()
        obs = env.step(1).observation

        output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [obs]})

        input_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=(None, ) + obs_dim)
        q_vals = qf.build(input_var, 'another')
        output2 = self.sess.run(q_vals, feed_dict={input_var: [obs]})

        assert np.array_equal(output1, output2)
    def test_flattened_image_input(self):
        env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
        gcb = GaussianCNNBaseline(env_spec=env.spec,
                                  filters=((3, (3, 3)), (6, (3, 3))),
                                  strides=(1, 1),
                                  padding='SAME',
                                  hidden_sizes=(32, ))
        env.reset()
        es = env.step(1)
        obs, rewards = es.observation, es.reward
        train_paths = [{'observations': [obs.flatten()], 'returns': [rewards]}]
        gcb.fit(train_paths)
        paths = {'observations': [obs.flatten()]}
        prediction = gcb.predict(paths)
        assert np.allclose(0., prediction)
    def test_get_action_img_obs(self, hidden_channels, kernel_sizes, strides,
                                hidden_sizes):
        """Test get_action function with akro.Image observation space."""
        env = GymEnv(self._initialize_obs_env(DummyDiscretePixelEnv()),
                     is_image=True)
        policy = CategoricalCNNPolicy(env=env,
                                      kernel_sizes=kernel_sizes,
                                      hidden_channels=hidden_channels,
                                      strides=strides,
                                      hidden_sizes=hidden_sizes)
        env.reset()
        obs = env.step(1).observation

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)
    def test_get_action(self, filters, strides, padding, hidden_sizes):
        env = GymEnv(DummyDiscretePixelEnv())
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      filters=filters,
                                      strides=strides,
                                      padding=padding,
                                      hidden_sizes=hidden_sizes)

        env.reset()
        obs = env.step(1).observation

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs, obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
    def test_get_action(self, kernel_sizes, hidden_channels, strides,
                        paddings):
        """Test get_action function."""
        env = GymEnv(DummyDiscretePixelEnv())
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   image_format='NHWC',
                                   hidden_channels=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)
        env.reset()
        obs = env.step(1).observation

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(int(action[0]))
    def test_build(self, obs_dim, action_dim):
        """Test build method"""
        env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        policy = ContinuousMLPPolicy(env_spec=env.spec)

        env.reset()
        obs = env.step(1).observation

        obs_dim = env.spec.observation_space.flat_dim
        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, obs_dim))
        action_sym = policy.build(state_input, name='action_sym')

        action = self.sess.run(action_sym,
                               feed_dict={state_input: [obs.flatten()]})
        action = policy.action_space.unflatten(action)

        assert env.action_space.contains(action)