Ejemplo n.º 1
0
    def test_get_qval_sym(self, obs_dim, action_dim):
        env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()
        obs_ph, act_ph = qf.inputs

        output1 = qf.get_qval([obs], [act])

        input_var1 = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, obs.shape[0]))
        input_var2 = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, act.shape[0]))
        q_vals = qf.get_qval_sym(input_var1, input_var2, 'another')
        output2 = self.sess.run(q_vals,
                                feed_dict={
                                    input_var1: [obs],
                                    input_var2: [act]
                                })

        expected_output = np.full((1, ), 0.5)

        assert np.array_equal(output1, output2)
        assert np.array_equal(output2[0], expected_output)
Ejemplo n.º 2
0
    def test_is_pickleable(self, obs_dim, action_dim):
        env = TfEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()
        obs_ph, act_ph = qf.inputs

        with tf.compat.v1.variable_scope(
                'ContinuousMLPQFunction/SimpleMLPMergeModel', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # assign it to all one
        return_var.load(tf.ones_like(return_var).eval())

        output1 = qf.get_qval([obs], [act])

        h_data = pickle.dumps(qf)
        with tf.compat.v1.Session(graph=tf.Graph()):
            qf_pickled = pickle.loads(h_data)
            obs_ph_pickled, act_ph_pickled = qf_pickled.inputs
            output2 = qf_pickled.get_qval([obs], [act])

        assert np.array_equal(output1, output2)
    def test_q_vals(self, obs_dim, action_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])

        for output in outputs:
            assert np.array_equal(output, expected_output)
    def test_q_vals_goal_conditioned(self):
        env = MetaRLEnv(DummyDictEnv())
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = np.concatenate(
            (obs['observation'], obs['desired_goal'], obs['achieved_goal']),
            axis=-1)
        act = np.full((1, ), 0.5).flatten()

        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])
        for output in outputs:
            assert np.array_equal(output, expected_output)
    def test_output_shape(self, obs_dim, action_dim):
        env = MetaRLEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        with mock.patch(('metarl.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = obs.flatten()
        act = np.full(action_dim, 0.5).flatten()

        outputs = qf.get_qval([obs], [act])

        assert outputs.shape == (1, 1)