Example #1
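Note: the snippets below are method-level excerpts from MetaRL's TensorFlow test suite, so their imports and fixture definitions are omitted. A rough header they rely on is sketched here; the metarl and tests.fixtures import paths are assumptions inferred from the mock.patch targets in the snippets, not verified against the repository.

import pickle
from itertools import chain
from unittest import mock

import akro
import gym
import numpy as np
import pytest
import tensorflow as tf

# Assumed import locations -- adjust to the actual metarl package layout.
from metarl import InOutSpec
from metarl.envs import MetaRLEnv, normalize
from metarl.tf.embeddings import GaussianMLPEncoder
from metarl.tf.policies import GaussianMLPTaskEmbeddingPolicy
from tests.fixtures import TfGraphTestCase
from tests.fixtures.envs.dummy import DummyBoxEnv
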
    def test_get_action(self, obs_dim, task_num, latent_dim, action_dim):
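        """Test get_action given latent, one-hot task and augmented obs"""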
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        embedding_spec = InOutSpec(
            input_space=akro.Box(low=np.zeros(task_num),
                                 high=np.ones(task_num)),
            output_space=akro.Box(low=np.zeros(latent_dim),
                                  high=np.ones(latent_dim)))
        encoder = GaussianMLPEncoder(embedding_spec)
        policy = GaussianMLPTaskEmbeddingPolicy(env_spec=env.spec,
                                                encoder=encoder)

        env.reset()
        obs, _, _, _ = env.step(1)
        latent = np.random.random((latent_dim, ))
        task = np.zeros(task_num)
        task[0] = 1

        action1, _ = policy.get_action_given_latent(obs, latent)
        action2, _ = policy.get_action_given_task(obs, task)
        action3, _ = policy.get_action(np.concatenate([obs.flatten(), task]))

        assert env.action_space.contains(action1)
        assert env.action_space.contains(action2)
        assert env.action_space.contains(action3)

        obses, latents, tasks = [obs] * 3, [latent] * 3, [task] * 3
        aug_obses = [np.concatenate([obs.flatten(), task])] * 3
        action1n, _ = policy.get_actions_given_latents(obses, latents)
        action2n, _ = policy.get_actions_given_tasks(obses, tasks)
        action3n, _ = policy.get_actions(aug_obses)

        for action in chain(action1n, action2n, action3n):
            assert env.action_space.contains(action)
    def test_get_qval_sym(self, obs_dim, action_dim):
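        """Test that get_qval_sym output matches get_qval"""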
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        output1 = qf.get_qval([obs], [act])

        input_var1 = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, obs.shape[0]))
        input_var2 = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, act.shape[0]))
        q_vals = qf.get_qval_sym(input_var1, input_var2, 'another')
        output2 = self.sess.run(q_vals,
                                feed_dict={
                                    input_var1: [obs],
                                    input_var2: [act]
                                })

        expected_output = np.full((1, ), 0.5)

        assert np.array_equal(output1, output2)
        assert np.array_equal(output2[0], expected_output)
class TestQfDerivedPolicy(TfGraphTestCase):
    def setup_method(self):
        super().setup_method()
        self.env = MetaRLEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQfDerivedPolicy(env_spec=self.env.spec,
                                              qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_derived_policy(self):
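        """Test that DiscreteQfDerivedPolicy actions are in the action space"""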
        obs, _, _, _ = self.env.step(1)
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        for action in actions:
            assert self.env.action_space.contains(action)

    def test_is_pickleable(self):
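        """Test if DiscreteQfDerivedPolicy is pickleable"""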
        with tf.compat.v1.variable_scope('SimpleQFunction/SimpleMLPModel',
                                         reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())
        obs, _, _, _ = self.env.step(1)
        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2
    def test_is_pickleable(self, obs_dim, action_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        with tf.compat.v1.variable_scope(
                'ContinuousMLPQFunction/SimpleMLPMergeModel', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())

        output1 = qf.get_qval([obs], [act])

        h_data = pickle.dumps(qf)
        with tf.compat.v1.Session(graph=tf.Graph()):
            qf_pickled = pickle.loads(h_data)
            output2 = qf_pickled.get_qval([obs], [act])

        assert np.array_equal(output1, output2)
    def test_is_pickleable(self, obs_dim, action_dim):
        """Test if ContinuousMLPPolicy is pickleable"""
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.policies.'
                         'continuous_mlp_policy.MLPModel'),
                        new=SimpleMLPModel):
            policy = ContinuousMLPPolicy(env_spec=env.spec)

        env.reset()
        obs, _, _, _ = env.step(1)

        with tf.compat.v1.variable_scope('ContinuousMLPPolicy/MLPModel',
                                         reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())
        output1 = self.sess.run(
            policy.model.outputs,
            feed_dict={policy.model.input: [obs.flatten()]})

        p = pickle.dumps(policy)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            output2 = sess.run(
                policy_pickled.model.outputs,
                feed_dict={policy_pickled.model.input: [obs.flatten()]})
            assert np.array_equal(output1, output2)
Example #6
    def test_time_limit_env(self):
        metarl_env = MetaRLEnv(env_name='Pendulum-v0')
        metarl_env.reset()
        for _ in range(200):
            _, _, done, info = metarl_env.step(
                metarl_env.spec.action_space.sample())
        assert not done and info['TimeLimit.truncated']
        assert info['MetaRLEnv.TimeLimitTerminated']

    def test_output_shape(self, obs_dim, action_dim):
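        """Test that get_qval output has shape (1, 1)"""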
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        outputs = qf.get_qval([obs], [act])

        assert outputs.shape == (1, 1)
    def test_get_action_state_include_action(self, obs_dim, action_dim,
                                             hidden_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[
                None, None,
                env.observation_space.flat_dim + np.prod(action_dim)
            ],
            name='obs')
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_dim=hidden_dim,
                                   state_include_action=True)

        policy.build(obs_var)
        policy.reset()
        obs = env.reset()

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)

        policy.reset()

        actions, _ = policy.get_actions([obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
Example #9
    def test_get_qval_max_pooling(self, filters, strides, pool_strides,
                                  pool_shapes):
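        """Test Q-values of ContinuousCNNQFunction with max pooling"""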
        env = MetaRLEnv(DummyDiscretePixelEnv())
        obs = env.reset()

        with mock.patch(('metarl.tf.models.'
                         'cnn_mlp_merge_model.CNNModelWithMaxPooling'),
                        new=SimpleCNNModelWithMaxPooling):
            with mock.patch(('metarl.tf.models.'
                             'cnn_mlp_merge_model.MLPMergeModel'),
                            new=SimpleMLPMergeModel):
                qf = ContinuousCNNQFunction(env_spec=env.spec,
                                            filters=filters,
                                            strides=strides,
                                            max_pooling=True,
                                            pool_strides=pool_strides,
                                            pool_shapes=pool_shapes)

        action_dim = env.action_space.shape

        obs, _, _, _ = env.step(1)

        act = np.full(action_dim, 0.5)
        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])

        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])

        for output in outputs:
            assert np.array_equal(output, expected_output)
Example #10
    def test_is_pickleable(self):
        env = MetaRLEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=False)

        policy.build(obs_var)
        policy.reset()
        obs = env.reset()

        policy.model._lstm_cell.weights[0].load(
            tf.ones_like(policy.model._lstm_cell.weights[0]).eval())

        output1 = self.sess.run(
            [policy.distribution.probs],
            feed_dict={policy.model.input: [[obs.flatten()], [obs.flatten()]]})

        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, env.observation_space.flat_dim],
                name='obs')
            policy_pickled.build(obs_var)
            output2 = sess.run([policy_pickled.distribution.probs],
                               feed_dict={
                                   policy_pickled.model.input:
                                   [[obs.flatten()], [obs.flatten()]]
                               })  # noqa: E126
            assert np.array_equal(output1, output2)
    def test_is_pickleable(self, obs_dim, action_dim):
        env = MetaRLEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalMLPPolicy(env_spec=env.spec)

        policy.build(obs_var)
        obs = env.reset()

        with tf.compat.v1.variable_scope(
                'CategoricalMLPPolicy/CategoricalMLPModel', reuse=True):
            bias = tf.compat.v1.get_variable('mlp/hidden_0/bias')
        # assign it to all one
        bias.load(tf.ones_like(bias).eval())
        output1 = self.sess.run(
            [policy.distribution.probs],
            feed_dict={policy.model.input: [[obs.flatten()]]})

        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, env.observation_space.flat_dim],
                name='obs')
            policy_pickled.build(obs_var)
            output2 = sess.run(
                [policy_pickled.distribution.probs],
                feed_dict={policy_pickled.model.input: [[obs.flatten()]]})
            assert np.array_equal(output1, output2)
    def test_get_embedding(self, obs_dim, embedding_dim):
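        """Test that single and batched latents lie in the output space"""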
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=embedding_dim))
        embedding_spec = InOutSpec(input_space=env.spec.observation_space,
                                   output_space=env.spec.action_space)
        embedding = GaussianMLPEncoder(embedding_spec)
        task_input = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, None,
                                                     embedding.input_dim))
        embedding.build(task_input)

        env.reset()
        obs, _, _, _ = env.step(1)

        latent, _ = embedding.get_latent(obs)
        latents, _ = embedding.get_latents([obs] * 5)
        assert env.action_space.contains(latent)
        for latent in latents:
            assert env.action_space.contains(latent)
    def test_is_pickleable(self):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=(1, ), action_dim=(1, )))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   state_include_action=False)

        policy.build(obs_var)
        obs = env.reset()
        with tf.compat.v1.variable_scope('GaussianGRUPolicy/GaussianGRUModel',
                                         reuse=True):
            param = tf.compat.v1.get_variable(
                'dist_params/log_std_param/parameter')
        # assign it to all one
        param.load(tf.ones_like(param).eval())

        output1 = self.sess.run(
            [policy.distribution.loc,
             policy.distribution.stddev()],
            feed_dict={policy.model.input: [[obs.flatten()], [obs.flatten()]]})

        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            obs_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=[None, None, env.observation_space.flat_dim],
                name='obs')
            policy_pickled.build(obs_var)
            # yapf: disable
            output2 = sess.run(
                [
                    policy_pickled.distribution.loc,
                    policy_pickled.distribution.stddev()
                ],
                feed_dict={
                    policy_pickled.model.input: [[obs.flatten()],
                                                 [obs.flatten()]]
                })
            assert np.array_equal(output1, output2)
Example #14
    def test_get_action(self, obs_dim, action_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = GaussianMLPPolicy(env_spec=env.spec)

        policy.build(obs_var)
        env.reset()
        obs, _, _, _ = env.step(1)

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)
        actions, _ = policy.get_actions(
            [obs.flatten(), obs.flatten(),
             obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
    def test_q_vals(self, obs_dim, action_dim):
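        """Test Q-values for single and batched observation-action pairs"""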
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])

        for output in outputs:
            assert np.array_equal(output, expected_output)
    def test_q_vals_goal_conditioned(self):
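        """Test Q-values with a flattened goal-conditioned observation"""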
        env = MetaRLEnv(DummyDictEnv())
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = np.concatenate(
            (obs['observation'], obs['desired_goal'], obs['achieved_goal']),
            axis=-1)
        act = np.full((1, ), 0.5).flatten()

        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])
        for output in outputs:
            assert np.array_equal(output, expected_output)
    def test_is_pickleable(self, obs_dim, embedding_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=embedding_dim))
        embedding_spec = InOutSpec(input_space=env.spec.observation_space,
                                   output_space=env.spec.action_space)
        embedding = GaussianMLPEncoder(embedding_spec)
        task_input = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, None,
                                                     embedding.input_dim))
        embedding.build(task_input, name='default')

        env.reset()
        obs, _, _, _ = env.step(1)
        obs_dim = env.spec.observation_space.flat_dim

        with tf.compat.v1.variable_scope('GaussianMLPEncoder/GaussianMLPModel',
                                         reuse=True):
            bias = tf.compat.v1.get_variable(
                'dist_params/mean_network/hidden_0/bias')
        # assign it to all one
        bias.load(tf.ones_like(bias).eval())
        output1 = self.sess.run(
            [embedding.distribution.loc,
             embedding.distribution.stddev()],
            feed_dict={embedding.model.input: [[obs.flatten()]]})

        p = pickle.dumps(embedding)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            embedding_pickled = pickle.loads(p)
            task_input = tf.compat.v1.placeholder(
                tf.float32, shape=(None, None, embedding_pickled.input_dim))
            embedding_pickled.build(task_input, name='default')

            output2 = sess.run(
                [
                    embedding_pickled.distribution.loc,
                    embedding_pickled.distribution.stddev()
                ],
                feed_dict={embedding_pickled.model.input: [[obs.flatten()]]})
            assert np.array_equal(output1, output2)
    def test_get_action(self, filters, strides, padding, hidden_sizes):
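        """Test that CategoricalCNNPolicy actions are in the action space"""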
        env = MetaRLEnv(DummyDiscretePixelEnv())
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      filters=filters,
                                      strides=strides,
                                      padding=padding,
                                      hidden_sizes=hidden_sizes)
        obs_var = tf.compat.v1.placeholder(tf.float32,
                                           shape=(None, None) +
                                           env.observation_space.shape,
                                           name='obs')
        policy.build(obs_var)

        env.reset()
        obs, _, _, _ = env.step(1)

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs, obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
    def test_is_pickleable(self):
        env = MetaRLEnv(DummyDiscretePixelEnv())
        policy = CategoricalCNNPolicy(env_spec=env.spec,
                                      filters=((3, (32, 32)), ),
                                      strides=(1, ),
                                      padding='SAME',
                                      hidden_sizes=(4, ))
        obs_var = tf.compat.v1.placeholder(tf.float32,
                                           shape=(None, None) +
                                           env.observation_space.shape,
                                           name='obs')
        policy.build(obs_var)

        env.reset()
        obs, _, _, _ = env.step(1)

        with tf.compat.v1.variable_scope(
                'CategoricalCNNPolicy/CategoricalCNNModel', reuse=True):
            cnn_bias = tf.compat.v1.get_variable('CNNModel/cnn/h0/bias')
            bias = tf.compat.v1.get_variable('MLPModel/mlp/hidden_0/bias')

        cnn_bias.load(tf.ones_like(cnn_bias).eval())
        bias.load(tf.ones_like(bias).eval())

        output1 = self.sess.run(policy.distribution.probs,
                                feed_dict={policy.model.input: [[obs]]})
        p = pickle.dumps(policy)

        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            policy_pickled = pickle.loads(p)
            obs_var = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, None) +
                                               env.observation_space.shape,
                                               name='obs')
            policy_pickled.build(obs_var)
            output2 = sess.run(policy_pickled.distribution.probs,
                               feed_dict={policy_pickled.model.input: [[obs]]})
            assert np.array_equal(output1, output2)
    def test_get_action(self, obs_dim, action_dim):
        """Test get_action method"""
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.policies.'
                         'continuous_mlp_policy.MLPModel'),
                        new=SimpleMLPModel):
            policy = ContinuousMLPPolicy(env_spec=env.spec)

        env.reset()
        obs, _, _, _ = env.step(1)

        action, _ = policy.get_action(obs.flatten())

        expected_action = np.full(action_dim, 0.5)

        assert env.action_space.contains(action)
        assert np.array_equal(action, expected_action)

        actions, _ = policy.get_actions(
            [obs.flatten(), obs.flatten(),
             obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
            assert np.array_equal(action, expected_action)
    def test_get_action_sym(self, obs_dim, action_dim):
        """Test get_action_sym method"""
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.policies.'
                         'continuous_mlp_policy.MLPModel'),
                        new=SimpleMLPModel):
            policy = ContinuousMLPPolicy(env_spec=env.spec)

        env.reset()
        obs, _, _, _ = env.step(1)

        obs_dim = env.spec.observation_space.flat_dim
        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, obs_dim))
        action_sym = policy.get_action_sym(state_input, name='action_sym')

        expected_action = np.full(action_dim, 0.5)

        action = self.sess.run(action_sym,
                               feed_dict={state_input: [obs.flatten()]})
        action = policy.action_space.unflatten(action)

        assert np.array_equal(action, expected_action)
        assert env.action_space.contains(action)
Example #22
    def test_get_qval(self, filters, strides):
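        """Test Q-values for single and batched pixel observations"""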
        env = MetaRLEnv(DummyDiscretePixelEnv())
        obs = env.reset()

        with mock.patch(('metarl.tf.models.'
                         'cnn_mlp_merge_model.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.models.'
                             'cnn_mlp_merge_model.MLPMergeModel'),
                            new=SimpleMLPMergeModel):
                qf = ContinuousCNNQFunction(env_spec=env.spec,
                                            filters=filters,
                                            strides=strides)

        action_dim = env.action_space.shape

        obs, _, _, _ = env.step(1)

        act = np.full(action_dim, 0.5)
        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])

        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])

        for output in outputs:
            assert np.array_equal(output, expected_output)

        # make sure observations are unflattened

        obs = env.observation_space.flatten(obs)
        qf._f_qval = mock.MagicMock()

        qf.get_qval([obs], [act])
        unflattened_obs = qf._f_qval.call_args_list[0][0][0]
        assert unflattened_obs.shape[1:] == env.spec.observation_space.shape

        qf.get_qval([obs, obs], [act, act])
        unflattened_obs = qf._f_qval.call_args_list[1][0][0]
        assert unflattened_obs.shape[1:] == env.spec.observation_space.shape
Example #23
    def test_get_action(self, obs_dim, action_dim, hidden_dim):
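        """Test that CategoricalGRUPolicy actions are in the action space"""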
        env = MetaRLEnv(
            DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        obs_var = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, None, env.observation_space.flat_dim],
            name='obs')
        policy = CategoricalGRUPolicy(env_spec=env.spec,
                                      hidden_dim=hidden_dim,
                                      state_include_action=False)
        policy.build(obs_var)
        policy.reset(do_resets=None)
        obs = env.reset()

        action, _ = policy.get_action(obs.flatten())
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs.flatten()])
        for action in actions:
            assert env.action_space.contains(action)
Example #24
    def test_is_pickleable(self, filters, strides):

        env = MetaRLEnv(DummyDiscretePixelEnv())
        obs = env.reset()

        with mock.patch(('metarl.tf.models.'
                         'cnn_mlp_merge_model.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.models.'
                             'cnn_mlp_merge_model.MLPMergeModel'),
                            new=SimpleMLPMergeModel):
                qf = ContinuousCNNQFunction(env_spec=env.spec,
                                            filters=filters,
                                            strides=strides)

        action_dim = env.action_space.shape

        obs, _, _, _ = env.step(1)
        act = np.full(action_dim, 0.5)
        _, _ = qf.inputs

        with tf.compat.v1.variable_scope(
                'ContinuousCNNQFunction/CNNMLPMergeModel/SimpleMLPMergeModel',
                reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())

        output1 = qf.get_qval([obs], [act])

        h_data = pickle.dumps(qf)
        with tf.compat.v1.Session(graph=tf.Graph()):
            qf_pickled = pickle.loads(h_data)
            _, _ = qf_pickled.inputs
            output2 = qf_pickled.get_qval([obs], [act])

        assert np.array_equal(output1, output2)
Example #25
    def test_get_qval_sym(self, filters, strides):
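        """Test that the symbolic Q-value graph matches get_qval"""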
        env = MetaRLEnv(DummyDiscretePixelEnv())
        obs = env.reset()

        with mock.patch(('metarl.tf.models.'
                         'cnn_mlp_merge_model.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.models.'
                             'cnn_mlp_merge_model.MLPMergeModel'),
                            new=SimpleMLPMergeModel):
                qf = ContinuousCNNQFunction(env_spec=env.spec,
                                            filters=filters,
                                            strides=strides)
        action_dim = env.action_space.shape

        obs, _, _, _ = env.step(1)
        act = np.full(action_dim, 0.5)

        output1 = qf.get_qval([obs], [act])

        input_var1 = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, ) + obs.shape)
        input_var2 = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, ) + act.shape)
        q_vals = qf.get_qval_sym(input_var1, input_var2, 'another')

        output2 = self.sess.run(q_vals,
                                feed_dict={
                                    input_var1: [obs],
                                    input_var2: [act]
                                })

        expected_output = np.full((1, ), 0.5)

        assert np.array_equal(output1, output2)
        assert np.array_equal(output2[0], expected_output)
Example #26
class TestNormalizedGym:
    def setup_method(self):
        self.env = MetaRLEnv(
            normalize(gym.make('Pendulum-v0'),
                      normalize_reward=True,
                      normalize_obs=True,
                      flatten_obs=True))

    def teardown_method(self):
        self.env.close()

    def test_does_not_modify_action(self):
        a = self.env.action_space.sample()
        a_copy = a.copy()  # independent copy, so an in-place change is caught
        self.env.reset()
        self.env.step(a)
        assert np.array_equal(a, a_copy)

    def test_flatten(self):
        for _ in range(10):
            self.env.reset()
            for _ in range(5):
                self.env.render()
                action = self.env.action_space.sample()
                next_obs, _, done, _ = self.env.step(action)
                assert next_obs.shape == self.env.observation_space.low.shape
                if done:
                    break

    def test_unflatten(self):
        for _ in range(10):
            self.env.reset()
            for _ in range(5):
                action = self.env.action_space.sample()
                next_obs, _, done, _ = self.env.step(action)
                # yapf: disable
                assert (self.env.observation_space.flatten(next_obs).shape
                        == (self.env.observation_space.flat_dim, ))
                # yapf: enable
                if done:
                    break
    def test_normalize_pixel_batch(self):
        env = MetaRLEnv(DummyDiscretePixelEnv(), is_image=True)
        obs = env.reset()
        obs_normalized = normalize_pixel_batch(obs)
        expected = [ob / 255.0 for ob in obs]
        assert np.allclose(obs_normalized, expected)

class TestDiscreteCNNQFunction(TfGraphTestCase):
    def setup_method(self):
        super().setup_method()
        self.env = MetaRLEnv(DummyDiscretePixelEnv())
        self.obs = self.env.reset()

    # yapf: disable
    @pytest.mark.parametrize('filters, strides', [
        (((5, (3, 3)), ), (1, )),
        (((5, (3, 3)), ), (2, )),
        (((5, (3, 3)), (5, (3, 3))), (1, 1)),
    ])
    # yapf: enable
    def test_get_action(self, filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          dueling=False)

        action_dim = self.env.action_space.n
        expected_output = np.full(action_dim, 0.5)
        outputs = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})
        assert np.array_equal(outputs[0], expected_output)
        outputs = self.sess.run(
            qf.q_vals, feed_dict={qf.input: [self.obs, self.obs, self.obs]})
        for output in outputs:
            assert np.array_equal(output, expected_output)

    @pytest.mark.parametrize('obs_dim', [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]])
    def test_invalid_obs_shape(self, obs_dim):
        boxEnv = MetaRLEnv(DummyDiscreteEnv(obs_dim=obs_dim))
        with pytest.raises(ValueError):
            DiscreteCNNQFunction(env_spec=boxEnv.spec,
                                 filters=((5, (3, 3)), ),
                                 strides=(2, ),
                                 dueling=False)

    def test_obs_is_image(self):
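        """Test observation normalization when is_image=True"""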
        image_env = MetaRLEnv(DummyDiscretePixelEnv(), is_image=True)
        with mock.patch(('metarl.tf.models.'
                         'categorical_cnn_model.CNNModel._build'),
                        autospec=True,
                        side_effect=CNNModel._build) as build:

            qf = DiscreteCNNQFunction(env_spec=image_env.spec,
                                      filters=((5, (3, 3)), ),
                                      strides=(2, ),
                                      dueling=False)
            normalized_obs = build.call_args_list[0][0][1]

            input_ph = qf.input
            assert input_ph != normalized_obs

            fake_obs = [np.full(image_env.spec.observation_space.shape, 255)]
            assert (self.sess.run(normalized_obs,
                                  feed_dict={input_ph: fake_obs}) == 1.).all()

            obs_dim = image_env.spec.observation_space.shape
            state_input = tf.compat.v1.placeholder(tf.uint8,
                                                   shape=(None, ) + obs_dim)

            qf.get_qval_sym(state_input, name='another')
            normalized_obs = build.call_args_list[1][0][1]

            fake_obs = [np.full(image_env.spec.observation_space.shape, 255)]
            assert (self.sess.run(normalized_obs,
                                  feed_dict={state_input:
                                             fake_obs}) == 1.).all()

    def test_obs_not_image(self):
        env = self.env
        with mock.patch(('metarl.tf.models.'
                         'categorical_cnn_model.CNNModel._build'),
                        autospec=True,
                        side_effect=CNNModel._build) as build:

            qf = DiscreteCNNQFunction(env_spec=env.spec,
                                      filters=((5, (3, 3)), ),
                                      strides=(2, ),
                                      dueling=False)
            normalized_obs = build.call_args_list[0][0][1]

            input_ph = qf.input
            assert input_ph == normalized_obs

            fake_obs = [np.full(env.spec.observation_space.shape, 255)]
            assert (self.sess.run(normalized_obs,
                                  feed_dict={input_ph:
                                             fake_obs}) == 255.).all()

            obs_dim = env.spec.observation_space.shape
            state_input = tf.compat.v1.placeholder(tf.float32,
                                                   shape=(None, ) + obs_dim)

            qf.get_qval_sym(state_input, name='another')
            normalized_obs = build.call_args_list[1][0][1]

            fake_obs = [np.full(env.spec.observation_space.shape, 255)]
            assert (self.sess.run(normalized_obs,
                                  feed_dict={state_input:
                                             fake_obs}) == 255).all()

    # yapf: disable
    @pytest.mark.parametrize('filters, strides', [
        (((5, (3, 3)), ), (1, )),
        (((5, (3, 3)), ), (2, )),
        (((5, (3, 3)), (5, (3, 3))), (1, 1)),
    ])
    # yapf: enable
    def test_get_action_dueling(self, filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPDuelingModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          dueling=True)

        action_dim = self.env.action_space.n
        expected_output = np.full(action_dim, 0.5)
        outputs = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})
        assert np.array_equal(outputs[0], expected_output)
        outputs = self.sess.run(
            qf.q_vals, feed_dict={qf.input: [self.obs, self.obs, self.obs]})
        for output in outputs:
            assert np.array_equal(output, expected_output)

    # yapf: disable
    @pytest.mark.parametrize('filters, strides, pool_strides, pool_shapes', [
        (((5, (3, 3)), ), (1, ), (1, 1), (1, 1)),  # noqa: E122
        (((5, (3, 3)), ), (2, ), (2, 2), (2, 2)),  # noqa: E122
        (((5, (3, 3)), (5, (3, 3))), (1, 1), (1, 1), (1, 1)),  # noqa: E122
        (((5, (3, 3)), (5, (3, 3))), (1, 1), (2, 2), (2, 2))  # noqa: E122
    ])  # noqa: E122
    # yapf: enable
    def test_get_action_max_pooling(self, filters, strides, pool_strides,
                                    pool_shapes):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModelWithMaxPooling'),
                        new=SimpleCNNModelWithMaxPooling):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          max_pooling=True,
                                          pool_strides=pool_strides,
                                          pool_shapes=pool_shapes,
                                          dueling=False)

        action_dim = self.env.action_space.n
        expected_output = np.full(action_dim, 0.5)
        outputs = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})
        assert np.array_equal(outputs[0], expected_output)
        outputs = self.sess.run(
            qf.q_vals, feed_dict={qf.input: [self.obs, self.obs, self.obs]})
        for output in outputs:
            assert np.array_equal(output, expected_output)

    # yapf: disable
    @pytest.mark.parametrize('filters, strides', [
        (((5, (3, 3)), ), (1, )),
        (((5, (3, 3)), ), (2, )),
        (((5, (3, 3)), (5, (3, 3))), (1, 1)),
    ])
    # yapf: enable
    def test_get_qval_sym(self, filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          dueling=False)
        output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})

        obs_dim = self.env.observation_space.shape
        action_dim = self.env.action_space.n

        input_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=(None, ) + obs_dim)
        q_vals = qf.get_qval_sym(input_var, 'another')
        output2 = self.sess.run(q_vals, feed_dict={input_var: [self.obs]})

        expected_output = np.full(action_dim, 0.5)

        assert np.array_equal(output1, output2)
        assert np.array_equal(output2[0], expected_output)

    # yapf: disable
    @pytest.mark.parametrize('filters, strides', [
        (((5, (3, 3)), ), (1, )),
        (((5, (3, 3)), ), (2, )),
        (((5, (3, 3)), (5, (3, 3))), (1, 1)),
    ])
    # yapf: enable
    def test_is_pickleable(self, filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          dueling=False)
        with tf.compat.v1.variable_scope(
                'DiscreteCNNQFunction/Sequential/SimpleMLPModel', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())

        output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})

        h_data = pickle.dumps(qf)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            qf_pickled = pickle.loads(h_data)
            output2 = sess.run(qf_pickled.q_vals,
                               feed_dict={qf_pickled.input: [self.obs]})

        assert np.array_equal(output1, output2)

    # yapf: disable
    @pytest.mark.parametrize('filters, strides', [
        (((5, (3, 3)), ), (1, )),
        (((5, (3, 3)), ), (2, )),
        (((5, (3, 3)), (5, (3, 3))), (1, 1)),
    ])
    # yapf: enable
    def test_clone(self, filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          dueling=False)
        qf_clone = qf.clone('another_qf')
        assert qf_clone._filters == qf._filters
        assert qf_clone._strides == qf._strides